From 1c7bb732e157a03f65546fec1eba66b22ea2261a Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:14:23 -0500 Subject: [PATCH 01/19] feat: add TDX quote parsing (abi_tdx.py) Implement QuoteV4 binary parser for Intel TDX attestation quotes. Includes header, TD quote body, QE report, signed data, and PCK certificate chain parsing with named constants for all offsets and sizes per the Intel TDX DCAP specification. Validates header fields (version, AK type, TEE type, QE vendor ID) and rejects unsupported QuoteV5 and certification data type 5. --- src/tinfoil/attestation/abi_tdx.py | 758 +++++++++++++++++++++++++++++ tests/test_tdx_abi.py | 548 +++++++++++++++++++++ 2 files changed, 1306 insertions(+) create mode 100644 src/tinfoil/attestation/abi_tdx.py create mode 100644 tests/test_tdx_abi.py diff --git a/src/tinfoil/attestation/abi_tdx.py b/src/tinfoil/attestation/abi_tdx.py new file mode 100644 index 0000000..e91482e --- /dev/null +++ b/src/tinfoil/attestation/abi_tdx.py @@ -0,0 +1,758 @@ +""" +TDX Quote parsing structures and constants. + +This module provides data structures and parsing logic for Intel TDX attestation +quotes in the QuoteV4 format. +""" + +import struct +from dataclasses import dataclass +from typing import List + +# ============================================================================= +# Constants +# ============================================================================= + +# Quote structure sizes +QUOTE_MIN_SIZE = 0x3FC # 1020 bytes minimum +HEADER_SIZE = 0x30 # 48 bytes +TD_QUOTE_BODY_SIZE = 0x248 # 584 bytes +QE_REPORT_SIZE = 0x180 # 384 bytes + +# Quote versions +QUOTE_VERSION_V4 = 4 +QUOTE_VERSION_V5 = 5 + +# TEE type for TDX +TEE_TDX = 0x00000081 + +# Attestation key type (ECDSA-256-with-P-256 curve) +ATTESTATION_KEY_TYPE_ECDSA_P256 = 2 + +# Certification data types +CERT_DATA_TYPE_PCK_CERT_CHAIN = 5 +CERT_DATA_TYPE_QE_REPORT = 6 + +# Field sizes +TEE_TCB_SVN_SIZE = 0x10 # 16 bytes +MR_SEAM_SIZE = 0x30 # 48 bytes +MR_SIGNER_SEAM_SIZE = 0x30 # 48 bytes +SEAM_ATTRIBUTES_SIZE = 0x08 # 8 bytes +TD_ATTRIBUTES_SIZE = 0x08 # 8 bytes +XFAM_SIZE = 0x08 # 8 bytes +MR_TD_SIZE = 0x30 # 48 bytes +MR_CONFIG_ID_SIZE = 0x30 # 48 bytes +MR_OWNER_SIZE = 0x30 # 48 bytes +MR_OWNER_CONFIG_SIZE = 0x30 # 48 bytes +RTMR_SIZE = 0x30 # 48 bytes +RTMR_COUNT = 4 +REPORT_DATA_SIZE = 0x40 # 64 bytes +QE_VENDOR_ID_SIZE = 0x10 # 16 bytes +USER_DATA_SIZE = 0x14 # 20 bytes +SIGNATURE_SIZE = 0x40 # 64 bytes +ATTESTATION_KEY_SIZE = 0x40 # 64 bytes +ECDSA_P256_COMPONENT_SIZE = 0x20 # 32 bytes per R or S component +SHA256_HASH_SIZE = 0x20 # 32 bytes +CERT_DATA_HEADER_SIZE = 6 # 2 bytes type + 4 bytes size +PCK_CERT_CHAIN_COUNT = 3 # leaf, intermediate, root + +# Intel QE Vendor ID: 939a7233-f79c-4ca9-940a-0db3957f0607 +INTEL_QE_VENDOR_ID = bytes.fromhex("939a7233f79c4ca9940a0db3957f0607") + +# ============================================================================= +# Header offsets (relative to quote start) +# ============================================================================= + +HEADER_VERSION_START = 0x00 +HEADER_VERSION_END = 0x02 +HEADER_AK_TYPE_START = 0x02 +HEADER_AK_TYPE_END = 0x04 +HEADER_TEE_TYPE_START = 0x04 +HEADER_TEE_TYPE_END = 0x08 +# Bytes 0x08-0x0C are reserved in QuoteV4. +# Note: Some older specs labeled these as QE_SVN/PCE_SVN, but they are +# always zero in practice. 
The actual SVN values come from: +# - PCE SVN: PCK certificate extensions (OID 1.2.840.113741.1.13.1.2.17) +# - QE ISV SVN: QE Report at offset 0x102 within certification data +HEADER_RESERVED1_START = 0x08 +HEADER_RESERVED1_END = 0x0C +HEADER_QE_VENDOR_ID_START = 0x0C +HEADER_QE_VENDOR_ID_END = 0x1C +HEADER_USER_DATA_START = 0x1C +HEADER_USER_DATA_END = 0x30 + +# ============================================================================= +# TdQuoteBody offsets (relative to body start at 0x30) +# ============================================================================= + +TD_TEE_TCB_SVN_START = 0x00 +TD_TEE_TCB_SVN_END = 0x10 +TD_MR_SEAM_START = 0x10 +TD_MR_SEAM_END = 0x40 +TD_MR_SIGNER_SEAM_START = 0x40 +TD_MR_SIGNER_SEAM_END = 0x70 +TD_SEAM_ATTRIBUTES_START = 0x70 +TD_SEAM_ATTRIBUTES_END = 0x78 +TD_ATTRIBUTES_START = 0x78 +TD_ATTRIBUTES_END = 0x80 +TD_XFAM_START = 0x80 +TD_XFAM_END = 0x88 +TD_MR_TD_START = 0x88 +TD_MR_TD_END = 0xB8 +TD_MR_CONFIG_ID_START = 0xB8 +TD_MR_CONFIG_ID_END = 0xE8 +TD_MR_OWNER_START = 0xE8 +TD_MR_OWNER_END = 0x118 +TD_MR_OWNER_CONFIG_START = 0x118 +TD_MR_OWNER_CONFIG_END = 0x148 +TD_RTMRS_START = 0x148 +TD_RTMRS_END = 0x208 +TD_REPORT_DATA_START = 0x208 +TD_REPORT_DATA_END = 0x248 + +# ============================================================================= +# Quote-level offsets +# ============================================================================= + +QUOTE_HEADER_START = 0x00 +QUOTE_HEADER_END = 0x30 +QUOTE_BODY_START = 0x30 +QUOTE_BODY_END = 0x278 +QUOTE_SIGNED_DATA_SIZE_START = 0x278 +QUOTE_SIGNED_DATA_SIZE_END = 0x27C +QUOTE_SIGNED_DATA_START = 0x27C + +# ============================================================================= +# SignedData offsets (relative to signed data start) +# ============================================================================= + +SIGNED_DATA_SIGNATURE_START = 0x00 +SIGNED_DATA_SIGNATURE_END = 0x40 +SIGNED_DATA_AK_START = 0x40 +SIGNED_DATA_AK_END = 0x80 +SIGNED_DATA_CERT_DATA_START = 0x80 + +# ============================================================================= +# QE Report offsets (within certification data) +# ============================================================================= + +QE_CPU_SVN_START = 0x00 +QE_CPU_SVN_END = 0x10 +QE_MISC_SELECT_START = 0x10 +QE_MISC_SELECT_END = 0x14 +QE_ATTRIBUTES_START = 0x30 +QE_ATTRIBUTES_END = 0x40 +QE_MR_ENCLAVE_START = 0x40 +QE_MR_ENCLAVE_END = 0x60 +QE_MR_SIGNER_START = 0x80 +QE_MR_SIGNER_END = 0xA0 +QE_ISV_PROD_ID_START = 0x100 +QE_ISV_PROD_ID_END = 0x102 +QE_ISV_SVN_START = 0x102 +QE_ISV_SVN_END = 0x104 +QE_REPORT_DATA_START = 0x140 +QE_REPORT_DATA_END = 0x180 + + +# ============================================================================= +# Data Structures +# ============================================================================= + +@dataclass +class TdxHeader: + """ + TDX Quote header (48 bytes). + + Contains quote metadata including version, attestation key type, + TEE type, and vendor information. + + Note: Bytes 8-11 are reserved. Some older specs labeled these as + QE_SVN/PCE_SVN, but they are always zero in practice. The actual + SVN values come from PCK certificate extensions and QE Report. 
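+
+    Example (a minimal sketch; ``raw_quote`` stands for a hypothetical
+    valid quote):
+
+        >>> hdr = _parse_header(raw_quote[:HEADER_SIZE])
+        >>> hdr.version, hex(hdr.tee_type)
+        (4, '0x81')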
+ """ + version: int # 2 bytes - must be 4 for QuoteV4 + attestation_key_type: int # 2 bytes - must be 2 (ECDSA-P256) + tee_type: int # 4 bytes - must be 0x81 (TDX) + reserved: bytes # 4 bytes - reserved (was QE_SVN/PCE_SVN in older specs) + qe_vendor_id: bytes # 16 bytes - Intel: 939a7233-f79c-4ca9-940a-0db3957f0607 + user_data: bytes # 20 bytes - Custom data from QE + + def __str__(self) -> str: + return ( + f"TdxHeader(version={self.version}, " + f"ak_type={self.attestation_key_type}, " + f"tee_type=0x{self.tee_type:x}, " + f"qe_vendor_id={self.qe_vendor_id.hex()})" + ) + + +@dataclass +class TdQuoteBody: + """ + TD Quote Body (584 bytes). + + Contains the TD's measurements and report data. This is the core + attestation data signed by the QE. + """ + tee_tcb_svn: bytes # 16 bytes - TEE TCB Security Version Number + mr_seam: bytes # 48 bytes - Measurement of SEAM module + mr_signer_seam: bytes # 48 bytes - Signer of SEAM module (zeros for Intel SEAM) + seam_attributes: bytes # 8 bytes - SEAM attributes + td_attributes: bytes # 8 bytes - TD attributes + xfam: bytes # 8 bytes - Extended feature mask + mr_td: bytes # 48 bytes - Measurement of TD (MRTD) + mr_config_id: bytes # 48 bytes - Config ID + mr_owner: bytes # 48 bytes - Owner measurement + mr_owner_config: bytes # 48 bytes - Owner config measurement + rtmrs: List[bytes] # 4 x 48 bytes - Runtime measurement registers + report_data: bytes # 64 bytes - Custom data (TLS key FP + HPKE key) + + def __str__(self) -> str: + if len(self.rtmrs) == RTMR_COUNT: + rtmr_lines = "".join( + f" rtmr{i}={self.rtmrs[i].hex()},\n" for i in range(RTMR_COUNT) + ) + else: + rtmr_lines = f" rtmrs=({len(self.rtmrs)} entries),\n" + return ( + f"TdQuoteBody(\n" + f" mr_td={self.mr_td.hex()},\n" + f"{rtmr_lines}" + f" mr_seam={self.mr_seam.hex()},\n" + f" report_data={self.report_data.hex()}\n" + f")" + ) + + def get_measurements(self) -> List[bytes]: + """Return the 5 TDX measurements: [MRTD, RTMR0, RTMR1, RTMR2, RTMR3].""" + return [self.mr_td] + self.rtmrs + + +@dataclass +class QeReport: + """ + Quoting Enclave Report (384 bytes). + + SGX enclave report from the Quoting Enclave, used to verify + the attestation key is bound to a legitimate QE. + """ + cpu_svn: bytes # 16 bytes + misc_select: int # 4 bytes + attributes: bytes # 16 bytes + mr_enclave: bytes # 32 bytes - QE enclave measurement + mr_signer: bytes # 32 bytes - QE signer measurement + isv_prod_id: int # 2 bytes - Product ID + isv_svn: int # 2 bytes - Security Version Number + report_data: bytes # 64 bytes - Contains hash of attestation key + + def __str__(self) -> str: + return ( + f"QeReport(mr_enclave={self.mr_enclave.hex()}, " + f"mr_signer={self.mr_signer.hex()}, " + f"isv_prod_id={self.isv_prod_id}, " + f"isv_svn={self.isv_svn})" + ) + + +@dataclass +class QeReportCertificationData: + """ + QE Report Certification Data. + + Contains the QE report, its signature, authentication data, + and the nested PCK certificate chain. + """ + qe_report: bytes # 384 bytes - Raw QE report + qe_report_parsed: QeReport # Parsed QE report + qe_report_signature: bytes # 64 bytes - ECDSA signature over QE report + qe_auth_data: bytes # Variable - Authentication data + pck_cert_chain_data: "PckCertChainData" # Nested PCK certificate chain + + +@dataclass +class PckCertChainData: + """ + PCK Certificate Chain Data. + + Contains the certification data type and the actual certificate + chain in PEM format. 
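+
+    The chain carries PCK_CERT_CHAIN_COUNT (3) concatenated PEM
+    certificates, ordered leaf, intermediate, root.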
+    """
+    cert_type: int  # 2 bytes - Should be 5 (PCK cert chain)
+    cert_data_size: int  # 4 bytes
+    cert_data: bytes  # Variable - PEM certificate chain
+
+
+@dataclass
+class CertificationData:
+    """
+    Certification Data from the quote (type 6 only).
+
+    Contains QE report certification data with nested PCK cert chain.
+    Type 5 (direct PCK cert chain) is not supported as it lacks the
+    QE report needed for attestation key binding verification.
+    """
+    cert_type: int  # 2 bytes - Must be 6 (QE report certification data)
+    cert_data_size: int  # 4 bytes - Size of certification data
+    qe_report_data: QeReportCertificationData
+
+    def get_pck_chain(self) -> PckCertChainData:
+        """Get the PCK certificate chain from the QE report certification data."""
+        return self.qe_report_data.pck_cert_chain_data
+
+
+@dataclass
+class SignedData:
+    """
+    Signed Data section of the quote.
+
+    Contains the quote signature, attestation public key, and
+    certification data (certificate chain).
+    """
+    signature: bytes  # 64 bytes - ECDSA signature (R || S)
+    attestation_key: bytes  # 64 bytes - Raw ECDSA P-256 public key
+    certification_data: CertificationData
+
+    def __str__(self) -> str:
+        return (
+            f"SignedData(signature={self.signature[:8].hex()}..., "
+            f"attestation_key={self.attestation_key[:8].hex()}..., "
+            f"cert_type={self.certification_data.cert_type})"
+        )
+
+
+@dataclass
+class QuoteV4:
+    """
+    TDX Quote Version 4.
+
+    The complete TDX attestation quote containing header, TD quote body,
+    and signed data with certification chain.
+    """
+    header: TdxHeader
+    td_quote_body: TdQuoteBody
+    signed_data_size: int
+    signed_data: SignedData
+    extra_bytes: bytes = b""  # Any trailing bytes after signed data
+
+    def __str__(self) -> str:
+        return (
+            f"QuoteV4(\n"
+            f"  header={self.header},\n"
+            f"  td_quote_body={self.td_quote_body},\n"
+            f"  signed_data_size={self.signed_data_size},\n"
+            f"  signed_data={self.signed_data}\n"
+            f")"
+        )
+
+    def get_measurements(self) -> List[bytes]:
+        """Return the 5 TDX measurements: [MRTD, RTMR0, RTMR1, RTMR2, RTMR3]."""
+        return self.td_quote_body.get_measurements()
+
+    def get_report_data(self) -> bytes:
+        """Return the 64-byte report data containing TLS key FP and HPKE key."""
+        return self.td_quote_body.report_data
+
+
+# =============================================================================
+# Parsing Functions
+# =============================================================================
+
+
+class TdxQuoteParseError(Exception):
+    """Raised when TDX quote parsing fails."""
+    pass
+
+
+def _parse_header(data: bytes) -> TdxHeader:
+    """
+    Parse the 48-byte TDX quote header.
+
+    Args:
+        data: 48 bytes of header data
+
+    Returns:
+        Parsed TdxHeader
+
+    Raises:
+        TdxQuoteParseError: If header is malformed
+    """
+    if len(data) < HEADER_SIZE:
+        raise TdxQuoteParseError(
+            f"Header too short: {len(data)} bytes, expected {HEADER_SIZE}"
+        )
+
+    version = struct.unpack_from("<H", data, HEADER_VERSION_START)[0]
+    attestation_key_type = struct.unpack_from("<H", data, HEADER_AK_TYPE_START)[0]
+    tee_type = struct.unpack_from("<I", data, HEADER_TEE_TYPE_START)[0]
+
+    return TdxHeader(
+        version=version,
+        attestation_key_type=attestation_key_type,
+        tee_type=tee_type,
+        reserved=data[HEADER_RESERVED1_START:HEADER_RESERVED1_END],
+        qe_vendor_id=data[HEADER_QE_VENDOR_ID_START:HEADER_QE_VENDOR_ID_END],
+        user_data=data[HEADER_USER_DATA_START:HEADER_USER_DATA_END],
+    )
+
+
+def _parse_td_quote_body(data: bytes) -> TdQuoteBody:
+    """
+    Parse the 584-byte TD quote body.
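+
+    Every field sits at a fixed offset, so parsing is plain slicing.
+    A sketch of two fields, using the offsets defined above:
+
+        >>> mr_td = data[TD_MR_TD_START:TD_MR_TD_END]                # 48 bytes
+        >>> rtmr0 = data[TD_RTMRS_START:TD_RTMRS_START + RTMR_SIZE]  # 48 bytes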
+
+    Args:
+        data: 584 bytes of TD quote body data
+
+    Returns:
+        Parsed TdQuoteBody
+
+    Raises:
+        TdxQuoteParseError: If body is malformed
+    """
+    if len(data) < TD_QUOTE_BODY_SIZE:
+        raise TdxQuoteParseError(
+            f"TD quote body too short: {len(data)} bytes, expected {TD_QUOTE_BODY_SIZE}"
+        )
+
+    # Parse RTMRs (4 x 48 bytes)
+    rtmrs = []
+    for i in range(RTMR_COUNT):
+        start = TD_RTMRS_START + (i * RTMR_SIZE)
+        end = start + RTMR_SIZE
+        rtmrs.append(data[start:end])
+
+    return TdQuoteBody(
+        tee_tcb_svn=data[TD_TEE_TCB_SVN_START:TD_TEE_TCB_SVN_END],
+        mr_seam=data[TD_MR_SEAM_START:TD_MR_SEAM_END],
+        mr_signer_seam=data[TD_MR_SIGNER_SEAM_START:TD_MR_SIGNER_SEAM_END],
+        seam_attributes=data[TD_SEAM_ATTRIBUTES_START:TD_SEAM_ATTRIBUTES_END],
+        td_attributes=data[TD_ATTRIBUTES_START:TD_ATTRIBUTES_END],
+        xfam=data[TD_XFAM_START:TD_XFAM_END],
+        mr_td=data[TD_MR_TD_START:TD_MR_TD_END],
+        mr_config_id=data[TD_MR_CONFIG_ID_START:TD_MR_CONFIG_ID_END],
+        mr_owner=data[TD_MR_OWNER_START:TD_MR_OWNER_END],
+        mr_owner_config=data[TD_MR_OWNER_CONFIG_START:TD_MR_OWNER_CONFIG_END],
+        rtmrs=rtmrs,
+        report_data=data[TD_REPORT_DATA_START:TD_REPORT_DATA_END],
+    )
+
+
+def _parse_qe_report(data: bytes) -> QeReport:
+    """
+    Parse the 384-byte QE report.
+
+    Args:
+        data: 384 bytes of QE report data
+
+    Returns:
+        Parsed QeReport
+
+    Raises:
+        TdxQuoteParseError: If report is malformed
+    """
+    if len(data) < QE_REPORT_SIZE:
+        raise TdxQuoteParseError(
+            f"QE report too short: {len(data)} bytes, expected {QE_REPORT_SIZE}"
+        )
+
+    return QeReport(
+        cpu_svn=data[QE_CPU_SVN_START:QE_CPU_SVN_END],
+        misc_select=struct.unpack_from("<I", data, QE_MISC_SELECT_START)[0],
+        attributes=data[QE_ATTRIBUTES_START:QE_ATTRIBUTES_END],
+        mr_enclave=data[QE_MR_ENCLAVE_START:QE_MR_ENCLAVE_END],
+        mr_signer=data[QE_MR_SIGNER_START:QE_MR_SIGNER_END],
+        isv_prod_id=struct.unpack_from("<H", data, QE_ISV_PROD_ID_START)[0],
+        isv_svn=struct.unpack_from("<H", data, QE_ISV_SVN_START)[0],
+        report_data=data[QE_REPORT_DATA_START:QE_REPORT_DATA_END],
+    )
+
+
+def _parse_pck_cert_chain(data: bytes) -> tuple[PckCertChainData, int]:
+    """
+    Parse PCK certificate chain data.
+
+    Args:
+        data: Raw bytes starting at PCK cert chain data
+
+    Returns:
+        Tuple of (PckCertChainData, bytes_consumed)
+
+    Raises:
+        TdxQuoteParseError: If data is malformed
+    """
+    if len(data) < CERT_DATA_HEADER_SIZE:
+        raise TdxQuoteParseError("PCK cert chain data too short for header")
+
+    cert_type = struct.unpack_from("<H", data, 0)[0]
+    cert_data_size = struct.unpack_from("<I", data, 2)[0]
+
+    if cert_type != CERT_DATA_TYPE_PCK_CERT_CHAIN:
+        raise TdxQuoteParseError(
+            f"Expected PCK cert chain type {CERT_DATA_TYPE_PCK_CERT_CHAIN}, got {cert_type}"
+        )
+
+    if len(data) < CERT_DATA_HEADER_SIZE + cert_data_size:
+        raise TdxQuoteParseError("PCK cert chain data truncated")
+
+    cert_data = data[CERT_DATA_HEADER_SIZE:CERT_DATA_HEADER_SIZE + cert_data_size]
+
+    return PckCertChainData(
+        cert_type=cert_type,
+        cert_data_size=cert_data_size,
+        cert_data=cert_data,
+    ), CERT_DATA_HEADER_SIZE + cert_data_size
+
+
+def _parse_qe_report_certification_data(data: bytes) -> tuple[QeReportCertificationData, int]:
+    """
+    Parse QE report certification data (type 6).
+
+    Structure:
+    - QE Report: 384 bytes
+    - QE Report Signature: 64 bytes
+    - QE Auth Data Size: 2 bytes
+    - QE Auth Data: variable
+    - PCK Cert Chain Data: variable (nested type 5)
+
+    Args:
+        data: Raw bytes starting at QE report certification data
+
+    Returns:
+        Tuple of (QeReportCertificationData, bytes_consumed)
+
+    Raises:
+        TdxQuoteParseError: If data is malformed
+    """
+    offset = 0
+
+    # QE Report (384 bytes)
+    if len(data) < offset + QE_REPORT_SIZE:
+        raise TdxQuoteParseError("Data too short for QE report")
+    qe_report_raw = data[offset:offset + QE_REPORT_SIZE]
+    qe_report_parsed = _parse_qe_report(qe_report_raw)
+    offset += QE_REPORT_SIZE
+
+    # QE Report Signature (64 bytes)
+    if len(data) < offset + SIGNATURE_SIZE:
+        raise TdxQuoteParseError("Data too short for QE report signature")
+    qe_report_signature = data[offset:offset + SIGNATURE_SIZE]
+    offset += SIGNATURE_SIZE
+
+    # QE Auth Data Size (2 bytes) + Auth Data
+    if len(data) < offset + 2:
+        raise TdxQuoteParseError("Data too short for QE auth data size")
+    qe_auth_data_size = struct.unpack_from("<H", data, offset)[0]
+    offset += 2
+    if len(data) < offset + qe_auth_data_size:
+        raise TdxQuoteParseError("Data too short for QE auth data")
+    qe_auth_data = data[offset:offset + qe_auth_data_size]
+    offset += qe_auth_data_size
+
+    # Nested PCK cert chain data (type 5)
+    pck_cert_chain_data, consumed = _parse_pck_cert_chain(data[offset:])
+    offset += consumed
+
+    return QeReportCertificationData(
+        qe_report=qe_report_raw,
+        qe_report_parsed=qe_report_parsed,
+        qe_report_signature=qe_report_signature,
+        qe_auth_data=qe_auth_data,
+        pck_cert_chain_data=pck_cert_chain_data,
+    ), offset
+
+
+def _parse_certification_data(data: bytes) -> tuple[CertificationData, int]:
+    """
+    Parse certification data from signed data section.
+
+    Only type 6 (QE report certification data) is supported. Type 5 (direct
+    PCK cert chain) is rejected as it lacks the QE report needed for
+    attestation key binding verification.
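+
+    Wire format: a 6-byte header (2-byte type, 4-byte little-endian size)
+    followed by the type-6 payload, which itself nests a type-5 PCK cert
+    chain entry after the QE report, signature, and auth data.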
+
+    Args:
+        data: Raw bytes starting at certification data
+
+    Returns:
+        Tuple of (CertificationData, bytes_consumed)
+
+    Raises:
+        TdxQuoteParseError: If data is malformed or unsupported type
+    """
+    if len(data) < CERT_DATA_HEADER_SIZE:
+        raise TdxQuoteParseError("Certification data too short for header")
+
+    cert_type = struct.unpack_from("<H", data, 0)[0]
+    cert_data_size = struct.unpack_from("<I", data, 2)[0]
+
+    if cert_type == CERT_DATA_TYPE_PCK_CERT_CHAIN:
+        raise TdxQuoteParseError(
+            f"Certification data type {CERT_DATA_TYPE_PCK_CERT_CHAIN} "
+            f"(direct PCK cert chain) is not supported"
+        )
+    if cert_type != CERT_DATA_TYPE_QE_REPORT:
+        raise TdxQuoteParseError(f"Unsupported certification data type: {cert_type}")
+
+    if len(data) < CERT_DATA_HEADER_SIZE + cert_data_size:
+        raise TdxQuoteParseError("Certification data truncated")
+
+    qe_report_data, _ = _parse_qe_report_certification_data(
+        data[CERT_DATA_HEADER_SIZE:CERT_DATA_HEADER_SIZE + cert_data_size]
+    )
+
+    return CertificationData(
+        cert_type=cert_type,
+        cert_data_size=cert_data_size,
+        qe_report_data=qe_report_data,
+    ), CERT_DATA_HEADER_SIZE + cert_data_size
+
+
+def _parse_signed_data(data: bytes) -> SignedData:
+    """
+    Parse the signed data section of the quote.
+
+    Structure:
+    - Signature: 64 bytes (ECDSA R || S)
+    - Attestation Key: 64 bytes (raw P-256 public key)
+    - Certification Data: variable
+
+    Args:
+        data: Raw bytes of signed data section
+
+    Returns:
+        Parsed SignedData
+
+    Raises:
+        TdxQuoteParseError: If data is malformed
+    """
+    min_size = SIGNATURE_SIZE + ATTESTATION_KEY_SIZE + CERT_DATA_HEADER_SIZE
+    if len(data) < min_size:
+        raise TdxQuoteParseError(
+            f"Signed data too short: {len(data)} bytes, minimum {min_size}"
+        )
+
+    signature = data[SIGNED_DATA_SIGNATURE_START:SIGNED_DATA_SIGNATURE_END]
+    attestation_key = data[SIGNED_DATA_AK_START:SIGNED_DATA_AK_END]
+
+    certification_data, _ = _parse_certification_data(data[SIGNED_DATA_CERT_DATA_START:])
+
+    return SignedData(
+        signature=signature,
+        attestation_key=attestation_key,
+        certification_data=certification_data,
+    )
+
+
+def _validate_header(header: TdxHeader) -> None:
+    """
+    Validate TDX quote header fields.
+
+    Args:
+        header: Parsed header to validate
+
+    Raises:
+        TdxQuoteParseError: If validation fails
+    """
+    if header.version == QUOTE_VERSION_V5:
+        raise TdxQuoteParseError(
+            "TDX QuoteV5 is not supported. Only QuoteV4 is implemented."
+        )
+
+    if header.version != QUOTE_VERSION_V4:
+        raise TdxQuoteParseError(
+            f"Unsupported quote version: {header.version}. Expected {QUOTE_VERSION_V4}."
+        )
+
+    if header.attestation_key_type != ATTESTATION_KEY_TYPE_ECDSA_P256:
+        raise TdxQuoteParseError(
+            f"Unsupported attestation key type: {header.attestation_key_type}. "
+            f"Expected {ATTESTATION_KEY_TYPE_ECDSA_P256} (ECDSA-P256)."
+        )
+
+    if header.tee_type != TEE_TDX:
+        raise TdxQuoteParseError(
+            f"Invalid TEE type: 0x{header.tee_type:x}. Expected 0x{TEE_TDX:x} (TDX)."
+        )
+
+    if header.qe_vendor_id != INTEL_QE_VENDOR_ID:
+        raise TdxQuoteParseError(
+            f"Unknown QE vendor ID: {header.qe_vendor_id.hex()}. "
+            f"Expected Intel QE: {INTEL_QE_VENDOR_ID.hex()}"
+        )
+
+
+def parse_quote(data: bytes) -> QuoteV4:
+    """
+    Parse a TDX attestation quote from raw bytes.
+
+    This is the main entry point for TDX quote parsing. It handles QuoteV4
+    format and explicitly rejects QuoteV5.
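+
+    Overall layout: 48-byte header || 584-byte TD quote body || 4-byte
+    little-endian signed data size || signed data || optional trailing
+    bytes (kept in ``extra_bytes``).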
+ + Args: + data: Raw quote bytes (typically from base64-decoded, gzip-decompressed + attestation document) + + Returns: + Parsed QuoteV4 structure + + Raises: + TdxQuoteParseError: If parsing fails or quote format is unsupported + + Example: + >>> raw_quote = base64.b64decode(attestation_doc) + >>> decompressed = gzip.decompress(raw_quote) + >>> quote = parse_quote(decompressed) + >>> measurements = quote.get_measurements() + """ + if len(data) < QUOTE_MIN_SIZE: + raise TdxQuoteParseError( + f"Quote too short: {len(data)} bytes, minimum {QUOTE_MIN_SIZE}" + ) + + # Parse header first to check version + header = _parse_header(data[QUOTE_HEADER_START:QUOTE_HEADER_END]) + _validate_header(header) + + # Parse TD quote body + td_quote_body = _parse_td_quote_body(data[QUOTE_BODY_START:QUOTE_BODY_END]) + + # Get signed data size and parse signed data + signed_data_size = struct.unpack_from( + " signed_data_end else b"" + + return QuoteV4( + header=header, + td_quote_body=td_quote_body, + signed_data_size=signed_data_size, + signed_data=signed_data, + extra_bytes=extra_bytes, + ) diff --git a/tests/test_tdx_abi.py b/tests/test_tdx_abi.py new file mode 100644 index 0000000..002a229 --- /dev/null +++ b/tests/test_tdx_abi.py @@ -0,0 +1,548 @@ +""" +Unit tests for TDX quote parsing (abi_tdx.py). +""" + +import pytest +import struct + +from tinfoil.attestation.abi_tdx import ( + # Constants + QUOTE_MIN_SIZE, + QUOTE_VERSION_V4, + QUOTE_VERSION_V5, + TEE_TDX, + ATTESTATION_KEY_TYPE_ECDSA_P256, + INTEL_QE_VENDOR_ID, + HEADER_SIZE, + TD_QUOTE_BODY_SIZE, + CERT_DATA_TYPE_PCK_CERT_CHAIN, + CERT_DATA_TYPE_QE_REPORT, + QE_REPORT_SIZE, + # Dataclasses + TdxHeader, + TdQuoteBody, + QeReport, + SignedData, + QuoteV4, + CertificationData, + PckCertChainData, + QeReportCertificationData, + # Parsing functions + parse_quote, + TdxQuoteParseError, + _parse_header, + _parse_td_quote_body, + _parse_qe_report, +) + + +# ============================================================================= +# Test Fixtures - Synthetic Quote Generation +# ============================================================================= + +def build_header( + version: int = QUOTE_VERSION_V4, + attestation_key_type: int = ATTESTATION_KEY_TYPE_ECDSA_P256, + tee_type: int = TEE_TDX, + reserved: bytes = b'\x00\x00\x00\x00', + qe_vendor_id: bytes = INTEL_QE_VENDOR_ID, + user_data: bytes = b'\x00' * 20, +) -> bytes: + """Build a synthetic TDX quote header (48 bytes). + + Note: Bytes 8-11 are reserved. Some older specs labeled these as + QE_SVN/PCE_SVN, but they are always zero in actual quotes. The real + SVN values come from PCK certificate extensions and QE Report. 
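+
+    The defaults satisfy _validate_header; override a single field to
+    exercise one rejection path at a time.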
+ """ + header = b'' + header += struct.pack(' bytes: + """Build a synthetic TD quote body (584 bytes).""" + body = b'' + body += tee_tcb_svn[:16].ljust(16, b'\x00') + body += mr_seam[:48].ljust(48, b'\x00') + body += mr_signer_seam[:48].ljust(48, b'\x00') + body += seam_attributes[:8].ljust(8, b'\x00') + body += td_attributes[:8].ljust(8, b'\x00') + body += xfam[:8].ljust(8, b'\x00') + body += mr_td[:48].ljust(48, b'\x00') + body += mr_config_id[:48].ljust(48, b'\x00') + body += mr_owner[:48].ljust(48, b'\x00') + body += mr_owner_config[:48].ljust(48, b'\x00') + body += rtmr0[:48].ljust(48, b'\x00') + body += rtmr1[:48].ljust(48, b'\x00') + body += rtmr2[:48].ljust(48, b'\x00') + body += rtmr3[:48].ljust(48, b'\x00') + body += report_data[:64].ljust(64, b'\x00') + assert len(body) == TD_QUOTE_BODY_SIZE + return body + + +def build_qe_report( + cpu_svn: bytes = b'\x00' * 16, + misc_select: int = 0, + attributes: bytes = b'\x00' * 16, + mr_enclave: bytes = b'\xee' * 32, + mr_signer: bytes = b'\xff' * 32, + isv_prod_id: int = 1, + isv_svn: int = 2, + report_data: bytes = b'\x00' * 64, +) -> bytes: + """Build a synthetic QE report (384 bytes).""" + report = b'' + report += cpu_svn[:16].ljust(16, b'\x00') # 0x00-0x10 + report += struct.pack(' bytes: + """Build PCK cert chain data (type 5).""" + return struct.pack(' bytes: + """Build signed data with type 6 (QE report certification data).""" + if qe_report is None: + qe_report = build_qe_report() + + # QE report certification data content + qe_cert_content = qe_report + qe_report_signature + struct.pack(' bytes: + """Build a complete synthetic TDX quote.""" + if header is None: + header = build_header() + if body is None: + body = build_td_quote_body() + if signed_data is None: + signed_data = build_signed_data() + + return header + body + struct.pack(' Date: Mon, 23 Feb 2026 16:14:28 -0500 Subject: [PATCH 02/19] feat: add TDX cryptographic verification (verify_tdx.py) Implement PCK certificate chain verification, ECDSA quote signature verification, QE report signature verification, and attestation key binding (AK hash in QE report data). All crypto operations use the cryptography library with proper algorithm constraints. --- src/tinfoil/attestation/verify_tdx.py | 330 ++++++++++++++++++ tests/test_tdx_verify.py | 465 ++++++++++++++++++++++++++ 2 files changed, 795 insertions(+) create mode 100644 src/tinfoil/attestation/verify_tdx.py create mode 100644 tests/test_tdx_verify.py diff --git a/src/tinfoil/attestation/verify_tdx.py b/src/tinfoil/attestation/verify_tdx.py new file mode 100644 index 0000000..0c11891 --- /dev/null +++ b/src/tinfoil/attestation/verify_tdx.py @@ -0,0 +1,330 @@ +""" +TDX Quote cryptographic verification. + +This module implements the cryptographic verification of TDX attestation quotes, +including signature verification and certificate chain validation. + +Verification flow: +1. Extract PCK certificate chain from quote +2. Verify PCK chain against Intel SGX Root CA +3. Verify quote signature (ECDSA P-256 over Header || TdQuoteBody) +4. Verify QE report signature using PCK leaf certificate +5. 
Verify QE report data binding (attestation key hash) +""" + +import hashlib +from dataclasses import dataclass + + +from cryptography import x509 +from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature +from cryptography.hazmat.primitives.asymmetric.utils import Prehashed + +from .abi_tdx import ( + QuoteV4, + QUOTE_HEADER_START, + QUOTE_BODY_END, + SIGNATURE_SIZE, + ATTESTATION_KEY_SIZE, + ECDSA_P256_COMPONENT_SIZE, + SHA256_HASH_SIZE, + PCK_CERT_CHAIN_COUNT, +) +from .cert_utils import parse_pem_chain, verify_intel_chain, CertificateChainError + + +class TdxVerificationError(Exception): + """Raised when TDX quote verification fails.""" + pass + + +@dataclass +class PCKCertificateChain: + """ + PCK Certificate Chain extracted from the quote. + + Contains three certificates: + - PCK leaf certificate (signs the QE report) + - Intermediate CA certificate + - Root CA certificate (should match Intel SGX Root CA) + """ + pck_cert: x509.Certificate # PCK Leaf certificate + intermediate_cert: x509.Certificate # Intermediate CA certificate + root_cert: x509.Certificate # Root CA certificate + + +def extract_pck_cert_chain(quote: QuoteV4) -> PCKCertificateChain: + """ + Extract the PCK certificate chain from the quote. + + The certificate chain is embedded in the quote's certification data + as concatenated PEM certificates: PCK Leaf || Intermediate CA || Root CA. + + Args: + quote: Parsed TDX QuoteV4 + + Returns: + PCKCertificateChain with three certificates + + Raises: + TdxVerificationError: If certificate chain is missing or malformed + """ + pck_chain_data = quote.signed_data.certification_data.get_pck_chain() + cert_pem = pck_chain_data.cert_data + if not cert_pem: + raise TdxVerificationError("PCK certificate chain is empty") + + # Parse concatenated PEM certificates + try: + certs = parse_pem_chain(cert_pem) + except CertificateChainError as e: + raise TdxVerificationError(f"Failed to parse PCK certificate chain: {e}") + + if len(certs) != PCK_CERT_CHAIN_COUNT: + raise TdxVerificationError( + f"PCK certificate chain should contain {PCK_CERT_CHAIN_COUNT} certificates, got {len(certs)}" + ) + + return PCKCertificateChain( + pck_cert=certs[0], + intermediate_cert=certs[1], + root_cert=certs[2], + ) + + +def verify_pck_chain(chain: PCKCertificateChain) -> None: + """ + Verify the PCK certificate chain against Intel SGX Root CA. + + Verification steps: + 1. Verify root cert matches embedded Intel SGX Root CA + 2. Verify certificate chain (validity + signatures) using cryptography library + + Args: + chain: PCK certificate chain from quote + + Raises: + TdxVerificationError: If chain verification fails + """ + certs = [chain.pck_cert, chain.intermediate_cert, chain.root_cert] + try: + verify_intel_chain(certs, "PCK certificate chain") + except CertificateChainError as e: + raise TdxVerificationError(str(e)) + + +def verify_quote_signature(quote: QuoteV4, raw_quote: bytes) -> None: + """ + Verify the quote signature using the attestation key. + + The signature covers SHA256(Header || TdQuoteBody). 
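+
+    Equivalently, as a sketch in terms of the abi_tdx offsets:
+
+        >>> message = raw_quote[QUOTE_HEADER_START:QUOTE_BODY_END]  # first 0x278 bytes
+        >>> digest = hashlib.sha256(message).digest()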
+ + Args: + quote: Parsed QuoteV4 + raw_quote: Original raw quote bytes + + Raises: + TdxVerificationError: If signature verification fails + """ + # Get attestation key (64 bytes = raw P-256 point) + attestation_key_bytes = quote.signed_data.attestation_key + public_key = _bytes_to_p256_pubkey(attestation_key_bytes) + + # Get signature (64 bytes = R || S) + signature_bytes = quote.signed_data.signature + signature_der = _signature_to_der(signature_bytes) + + # Message = Header || TdQuoteBody (bytes 0x00 to 0x278) + message = raw_quote[QUOTE_HEADER_START:QUOTE_BODY_END] + message_hash = hashlib.sha256(message).digest() + + # Verify ECDSA signature over pre-hashed message + # Use Prehashed to avoid double-hashing (we already computed SHA256) + try: + public_key.verify( + signature_der, + message_hash, + ec.ECDSA(Prehashed(hashes.SHA256())) + ) + except InvalidSignature: + raise TdxVerificationError( + "Quote signature verification failed: signature does not match" + ) + + +def verify_qe_report_signature(quote: QuoteV4, pck_cert: x509.Certificate) -> None: + """ + Verify the QE report signature using the PCK leaf certificate. + + Caller must ensure quote.signed_data.certification_data.qe_report_data + is not None before calling this function. + + Args: + quote: Parsed QuoteV4 + pck_cert: PCK leaf certificate from the chain + + Raises: + TdxVerificationError: If signature verification fails + """ + qe_report_data = quote.signed_data.certification_data.qe_report_data + + # Get QE report and signature + qe_report = qe_report_data.qe_report # 384 bytes + qe_signature = qe_report_data.qe_report_signature # 64 bytes + + # Convert signature to DER format + signature_der = _signature_to_der(qe_signature) + + # Verify using PCK certificate's public key + try: + pck_public_key = pck_cert.public_key() + pck_public_key.verify(signature_der, qe_report, ec.ECDSA(hashes.SHA256())) + except InvalidSignature: + raise TdxVerificationError( + "QE report signature verification failed using PCK certificate" + ) + except (TypeError, AttributeError, ValueError, UnsupportedAlgorithm) as e: + raise TdxVerificationError( + f"PCK certificate has unexpected key type: {e}" + ) + + +def verify_qe_report_data_binding(quote: QuoteV4) -> None: + """ + Verify that the QE report data binds to the attestation key. + + This is a CRITICAL security check. The QE report's report_data field + must contain SHA256(attestation_key || qe_auth_data) padded to 64 bytes. + Without this check, an attacker could substitute a different attestation key. + + Caller must ensure quote.signed_data.certification_data.qe_report_data + is not None before calling this function. + + Args: + quote: Parsed QuoteV4 + + Raises: + TdxVerificationError: If binding verification fails + """ + qe_report_data = quote.signed_data.certification_data.qe_report_data + + # Get components + attestation_key = quote.signed_data.attestation_key # 64 bytes + qe_auth_data = qe_report_data.qe_auth_data # Variable + qe_report_data_field = qe_report_data.qe_report_parsed.report_data # 64 bytes + + # Compute expected: SHA256(attestation_key || qe_auth_data) || zeros + data_to_hash = attestation_key + qe_auth_data + expected_hash = hashlib.sha256(data_to_hash).digest() + expected_report_data = expected_hash + b'\x00' * SHA256_HASH_SIZE + + # Both values are public attestation data; no timing side-channel concern. 
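+    # Expected layout of the 64-byte field: the 32-byte SHA256 digest of
+    # (attestation_key || qe_auth_data) followed by 32 zero bytes of padding.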
+ if qe_report_data_field != expected_report_data: + raise TdxVerificationError( + "QE report data binding verification failed: " + "SHA256(attestation_key || auth_data) does not match QE report data. " + "The attestation key may have been tampered with." + ) + + +def _bytes_to_p256_pubkey(key_bytes: bytes) -> ec.EllipticCurvePublicKey: + """ + Convert raw 64-byte P-256 public key to cryptography public key object. + + The raw format is X || Y (32 bytes each). + + Args: + key_bytes: 64 bytes representing X || Y coordinates + + Returns: + EllipticCurvePublicKey object + + Raises: + TdxVerificationError: If key format is invalid + """ + if len(key_bytes) != ATTESTATION_KEY_SIZE: + raise TdxVerificationError( + f"Attestation key is {len(key_bytes)} bytes, expected {ATTESTATION_KEY_SIZE}" + ) + + # Convert to uncompressed point format (0x04 || X || Y) + uncompressed = b'\x04' + key_bytes + + try: + return ec.EllipticCurvePublicKey.from_encoded_point( + ec.SECP256R1(), uncompressed + ) + except (ValueError, TypeError) as e: + raise TdxVerificationError(f"Invalid attestation key: {e}") + + +def _signature_to_der(sig_bytes: bytes) -> bytes: + """ + Convert raw R||S signature to DER format. + + TDX signatures are 64 bytes: R (32 bytes) || S (32 bytes). + DER format is required by cryptography library. + + Args: + sig_bytes: 64-byte raw signature + + Returns: + DER-encoded signature + + Raises: + TdxVerificationError: If signature format is invalid + """ + if len(sig_bytes) != SIGNATURE_SIZE: + raise TdxVerificationError( + f"Signature is {len(sig_bytes)} bytes, expected {SIGNATURE_SIZE}" + ) + + r = int.from_bytes(sig_bytes[0:ECDSA_P256_COMPONENT_SIZE], byteorder='big') + s = int.from_bytes(sig_bytes[ECDSA_P256_COMPONENT_SIZE:SIGNATURE_SIZE], byteorder='big') + + return encode_dss_signature(r, s) + + +def verify_tdx_quote(quote: QuoteV4, raw_quote: bytes) -> PCKCertificateChain: + """ + Perform full cryptographic verification of a TDX quote. + + This is the main entry point for TDX verification. It performs all + security-critical checks in order: + + 1. Extract and verify PCK certificate chain + 2. Verify quote signature using attestation key + 3. Verify QE report signature using PCK certificate + 4. Verify QE report data binding (critical security check) + + Args: + quote: Parsed QuoteV4 structure + raw_quote: Original raw quote bytes (needed for signature verification) + + Returns: + PCKCertificateChain on success (for further use in TCB checks) + + Raises: + TdxVerificationError: If any verification step fails + """ + # Step 1: Extract and verify PCK certificate chain + chain = extract_pck_cert_chain(quote) + verify_pck_chain(chain) + + # Step 2: Verify quote signature + verify_quote_signature(quote, raw_quote) + + # Step 3 & 4 require QE report data + if quote.signed_data.certification_data.qe_report_data is None: + raise TdxVerificationError("QE report data is missing") + + # Step 3: Verify QE report signature + verify_qe_report_signature(quote, chain.pck_cert) + + # Step 4: Verify QE report data binding (CRITICAL) + verify_qe_report_data_binding(quote) + + return chain diff --git a/tests/test_tdx_verify.py b/tests/test_tdx_verify.py new file mode 100644 index 0000000..fb8c9f1 --- /dev/null +++ b/tests/test_tdx_verify.py @@ -0,0 +1,465 @@ +""" +Unit tests for TDX quote verification (verify_tdx.py). 
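+
+Quotes are synthesized and signed with freshly generated P-256 keys, so
+the signature tests exercise real ECDSA verification rather than static
+fixtures.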
+""" + +import hashlib +import pytest +import struct + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, Prehashed + +from tinfoil.attestation.abi_tdx import ( + parse_quote, + QUOTE_HEADER_START, + QUOTE_BODY_END, + QE_REPORT_SIZE, + HEADER_SIZE, + TD_QUOTE_BODY_SIZE, + CERT_DATA_TYPE_QE_REPORT, + CERT_DATA_TYPE_PCK_CERT_CHAIN, +) +from tinfoil.attestation.verify_tdx import ( + TdxVerificationError, + PCKCertificateChain, + extract_pck_cert_chain, + verify_pck_chain, + verify_quote_signature, + verify_qe_report_signature, + verify_qe_report_data_binding, + verify_tdx_quote, + _bytes_to_p256_pubkey, + _signature_to_der, +) +from tinfoil.attestation.cert_utils import parse_pem_chain +from tinfoil.attestation.intel_root_ca import get_intel_root_ca, INTEL_SGX_ROOT_CA_PEM + + +# ============================================================================= +# Test Fixtures - Cryptographic Key Generation +# ============================================================================= + +def generate_p256_keypair(): + """Generate a P-256 key pair for testing.""" + private_key = ec.generate_private_key(ec.SECP256R1()) + public_key = private_key.public_key() + return private_key, public_key + + +def public_key_to_raw_bytes(public_key: ec.EllipticCurvePublicKey) -> bytes: + """Convert P-256 public key to raw 64-byte format (X || Y).""" + # Get uncompressed point (0x04 || X || Y) + uncompressed = public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) + # Strip the 0x04 prefix + return uncompressed[1:] + + +def sign_message(private_key: ec.EllipticCurvePrivateKey, message: bytes) -> bytes: + """Sign a message and return raw R||S signature (64 bytes). + + This hashes the message with SHA256 before signing. + """ + # Sign with SHA256 + der_signature = private_key.sign(message, ec.ECDSA(hashes.SHA256())) + + # Convert DER to raw R||S + r, s = decode_dss_signature(der_signature) + return r.to_bytes(32, byteorder='big') + s.to_bytes(32, byteorder='big') + + +def sign_prehashed(private_key: ec.EllipticCurvePrivateKey, digest: bytes) -> bytes: + """Sign a pre-hashed digest and return raw R||S signature (64 bytes). + + Use this when the message has already been hashed with SHA256. + """ + # Sign the pre-hashed digest + der_signature = private_key.sign(digest, ec.ECDSA(Prehashed(hashes.SHA256()))) + + # Convert DER to raw R||S + r, s = decode_dss_signature(der_signature) + return r.to_bytes(32, byteorder='big') + s.to_bytes(32, byteorder='big') + + +# ============================================================================= +# Test Fixtures - Quote Building with Real Signatures +# ============================================================================= + +def build_header_bytes() -> bytes: + """Build a valid TDX header. + + Note: Bytes 8-11 are reserved. Some older specs labeled these as + QE_SVN/PCE_SVN, but they are always zero in actual quotes. The real + SVN values come from PCK certificate extensions and QE Report. 
+ """ + header = b'' + header += struct.pack(' bytes: + """Build a valid TdQuoteBody.""" + body = b'' + body += b'\x03' + b'\x00' * 15 # tee_tcb_svn + body += b'\xaa' * 48 # mr_seam + body += b'\x00' * 48 # mr_signer_seam + body += b'\x00' * 8 # seam_attributes + body += b'\x00\x00\x00\x10\x00\x00\x00\x00' # td_attributes + body += b'\xe7\x02\x06\x00\x00\x00\x00\x00' # xfam + body += b'\x11' * 48 # mr_td + body += b'\x00' * 48 # mr_config_id + body += b'\x00' * 48 # mr_owner + body += b'\x00' * 48 # mr_owner_config + body += b'\x22' * 48 # rtmr0 + body += b'\x33' * 48 # rtmr1 + body += b'\x44' * 48 # rtmr2 + body += b'\x00' * 48 # rtmr3 + body += b'\xab' * 32 + b'\xcd' * 32 # report_data + assert len(body) == TD_QUOTE_BODY_SIZE + return body + + +def build_qe_report_bytes(report_data: bytes = None) -> bytes: + """Build a QE report with specified report_data.""" + if report_data is None: + report_data = b'\x00' * 64 + + report = b'' + report += b'\x00' * 16 # cpu_svn + report += struct.pack(' bytes: + """Build the signed data section of the quote.""" + # PCK cert chain (type 5) + pck_chain = struct.pack(' bytes: + """Generate a self-signed certificate for testing.""" + from cryptography.x509.oid import NameOID + from cryptography.hazmat.primitives.asymmetric import ec + import datetime + + private_key = ec.generate_private_key(ec.SECP256R1()) + + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, cn), + ]) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=365)) + .sign(private_key, hashes.SHA256()) + ) + + return cert.public_bytes(serialization.Encoding.PEM) + + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + +class TestBytesToP256PubKey: + """Test _bytes_to_p256_pubkey helper.""" + + def test_valid_key(self): + """Test converting valid key bytes.""" + _, public_key = generate_p256_keypair() + raw_bytes = public_key_to_raw_bytes(public_key) + + result = _bytes_to_p256_pubkey(raw_bytes) + assert isinstance(result, ec.EllipticCurvePublicKey) + + def test_wrong_size(self): + """Test that wrong size raises error.""" + with pytest.raises(TdxVerificationError, match="expected 64"): + _bytes_to_p256_pubkey(b'\x00' * 63) + + +class TestSignatureToDer: + """Test _signature_to_der helper.""" + + def test_valid_signature(self): + """Test converting valid signature.""" + private_key, _ = generate_p256_keypair() + raw_sig = sign_message(private_key, b"test message") + + der_sig = _signature_to_der(raw_sig) + assert isinstance(der_sig, bytes) + assert len(der_sig) > 64 # DER is longer due to encoding + + def test_wrong_size(self): + """Test that wrong size raises error.""" + with pytest.raises(TdxVerificationError, match="expected 64"): + _signature_to_der(b'\x00' * 63) + + +class TestParsePemChain: + """Test parse_pem_chain helper.""" + + def test_parse_single_cert(self): + """Test parsing single certificate.""" + cert_pem = _generate_self_signed_cert_pem("Test") + certs = parse_pem_chain(cert_pem) + assert len(certs) == 1 + + def test_parse_multiple_certs(self): + """Test parsing multiple concatenated certificates.""" + cert1 = _generate_self_signed_cert_pem("Cert1") 
+ cert2 = _generate_self_signed_cert_pem("Cert2") + cert3 = _generate_self_signed_cert_pem("Cert3") + + chain = cert1 + cert2 + cert3 + certs = parse_pem_chain(chain) + assert len(certs) == 3 + + def test_parse_with_whitespace(self): + """Test parsing with leading/trailing whitespace.""" + cert_pem = b'\n\n' + _generate_self_signed_cert_pem("Test") + b'\n\n' + certs = parse_pem_chain(cert_pem) + assert len(certs) == 1 + + +# ============================================================================= +# Quote Signature Verification Tests +# ============================================================================= + +class TestVerifyQuoteSignature: + """Test quote signature verification.""" + + def test_valid_signature(self): + """Test verification of valid quote signature.""" + raw_quote, attest_priv, _ = build_signed_quote_with_keys() + quote = parse_quote(raw_quote) + + # Should not raise + verify_quote_signature(quote, raw_quote) + + def test_tampered_body(self): + """Test that tampered body fails verification.""" + raw_quote, _, _ = build_signed_quote_with_keys() + + # Tamper with the body + tampered = bytearray(raw_quote) + tampered[100] ^= 0xFF # Flip a byte in the body + tampered = bytes(tampered) + + quote = parse_quote(tampered) + + with pytest.raises(TdxVerificationError, match="signature verification failed"): + verify_quote_signature(quote, tampered) + + +# ============================================================================= +# QE Report Data Binding Tests +# ============================================================================= + +class TestVerifyQeReportDataBinding: + """Test QE report data binding verification.""" + + def test_valid_binding(self): + """Test verification of valid binding.""" + raw_quote, _, _ = build_signed_quote_with_keys() + quote = parse_quote(raw_quote) + + # Should not raise + verify_qe_report_data_binding(quote) + + def test_tampered_auth_data(self): + """Test that wrong auth data fails binding check.""" + # Build quote with specific auth data + attest_priv, attest_pub = generate_p256_keypair() + + header = build_header_bytes() + body = build_body_bytes() + message_hash = hashlib.sha256(header + body).digest() + quote_signature = sign_prehashed(attest_priv, message_hash) + + attestation_key_raw = public_key_to_raw_bytes(attest_pub) + + # Use correct auth data for hash calculation + correct_auth_data = b'\x88' * 32 + qe_report_data_hash = hashlib.sha256(attestation_key_raw + correct_auth_data).digest() + qe_report_data = qe_report_data_hash + b'\x00' * 32 + + qe_report = build_qe_report_bytes(qe_report_data) + qe_priv, _ = generate_p256_keypair() + qe_signature = sign_message(qe_priv, qe_report) + + # But embed WRONG auth data in the quote + wrong_auth_data = b'\x99' * 32 # Different! + + cert_chain = ( + _generate_self_signed_cert_pem("PCK") + + _generate_self_signed_cert_pem("Int") + + INTEL_SGX_ROOT_CA_PEM + ) + + signed_data = _build_signed_data_bytes( + signature=quote_signature, + attestation_key=attestation_key_raw, + qe_report=qe_report, + qe_signature=qe_signature, + qe_auth_data=wrong_auth_data, # Wrong! + cert_chain_pem=cert_chain, + ) + + raw_quote = header + body + struct.pack(' Date: Mon, 23 Feb 2026 16:14:35 -0500 Subject: [PATCH 03/19] feat: add TDX policy validation (validate_tdx.py) Implement XFAM, TD_ATTRIBUTES, SEAM attributes, and measurement policy validation with frozen dataclass options. Unconditionally rejects debug TDs per Intel DCAP spec Section 2.3.2. 
Supports exact match and allowlist modes for MR_TD and MR_SEAM. --- src/tinfoil/attestation/validate_tdx.py | 423 ++++++++++++++++++++++++ tests/test_validate_tdx.py | 421 +++++++++++++++++++++++ 2 files changed, 844 insertions(+) create mode 100644 src/tinfoil/attestation/validate_tdx.py create mode 100644 tests/test_validate_tdx.py diff --git a/src/tinfoil/attestation/validate_tdx.py b/src/tinfoil/attestation/validate_tdx.py new file mode 100644 index 0000000..ac4bb0b --- /dev/null +++ b/src/tinfoil/attestation/validate_tdx.py @@ -0,0 +1,423 @@ +""" +TDX Policy Validation. + +This module provides pure policy validation functions for TDX quote fields. + +This module contains: +- Policy validation functions (validate_xfam, validate_td_attributes, etc.) +- Policy option dataclasses (PolicyOptions, HeaderOptions, TdQuoteBodyOptions) +- Policy-related constants (XFAM_FIXED*, TD_ATTRIBUTES_FIXED*, etc.) +""" + +import struct +from dataclasses import dataclass, field +from typing import Optional + +from .abi_tdx import ( + QuoteV4, + MR_SEAM_SIZE, + TD_ATTRIBUTES_SIZE, + XFAM_SIZE, + MR_TD_SIZE, + MR_CONFIG_ID_SIZE, + MR_OWNER_SIZE, + MR_OWNER_CONFIG_SIZE, + RTMR_SIZE, + RTMR_COUNT, + REPORT_DATA_SIZE, + QE_VENDOR_ID_SIZE, + SEAM_ATTRIBUTES_SIZE, + MR_SIGNER_SEAM_SIZE, + TEE_TCB_SVN_SIZE, +) + + +# ============================================================================= +# Policy Validation Error +# ============================================================================= + +class TdxValidationError(Exception): + """Raised when TDX policy validation fails.""" + pass + + +# ============================================================================= +# Policy Validation Constants +# ============================================================================= + +# XFAM fixed bit constraints +# If bit X is 1 in XFAM_FIXED1, it must be 1 in any XFAM +XFAM_FIXED1 = 0x00000003 +# If bit X is 0 in XFAM_FIXED0, it must be 0 in any XFAM +XFAM_FIXED0 = 0x0006DBE7 + +# TD_ATTRIBUTES fixed bit constraints +# No FIXED1 bits for TD_ATTRIBUTES currently (no bits are mandatory-set). + +# TD_ATTRIBUTES bit definitions +TD_ATTRIBUTES_DEBUG_BIT = 0x1 # Bit 0: DEBUG mode +TD_ATTRIBUTES_SEPT_VE_DIS = 1 << 28 # Bit 28: Disable EPT violation #VE +TD_ATTRIBUTES_PKS = 1 << 30 # Bit 30: Supervisor Protection Keys +TD_ATTRIBUTES_PERFMON = 1 << 63 # Bit 63: Performance monitoring + +# If bit X is 0 in TD_ATTRIBUTES_FIXED0, it must be 0 in any TD_ATTRIBUTES +# Supported bits: DEBUG, SEPT_VE_DIS, PKS, PERFMON +TD_ATTRIBUTES_FIXED0 = ( + TD_ATTRIBUTES_DEBUG_BIT | + TD_ATTRIBUTES_SEPT_VE_DIS | + TD_ATTRIBUTES_PKS | + TD_ATTRIBUTES_PERFMON +) + + + +# ============================================================================= +# Policy Validation Options +# ============================================================================= + +@dataclass(frozen=True) +class HeaderOptions: + """ + Validation options for TDX quote header fields. + Mirrors Go's validate.HeaderOptions struct. + + All fields are optional - set to None to skip the check. + """ + # Expected QE_VENDOR_ID (16 bytes, not checked if None) + qe_vendor_id: Optional[bytes] = None + + +@dataclass(frozen=True) +class TdQuoteBodyOptions: + """ + Validation options for TDX quote body fields. + Mirrors Go's validate.TdQuoteBodyOptions struct. + + All fields are optional - set to None to skip the check. 
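+
+    Example (hypothetical policy; the values are illustrative only):
+
+        >>> opts = TdQuoteBodyOptions(mr_td=bytes.fromhex("11" * 48),
+        ...                           td_attributes=bytes(8))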
+ """ + # Minimum TEE TCB SVN (16 bytes, component-wise comparison) + minimum_tee_tcb_svn: Optional[bytes] = None + # Expected MR_SEAM (48 bytes) + mr_seam: Optional[bytes] = None + # Expected TD_ATTRIBUTES (8 bytes) + td_attributes: Optional[bytes] = None + # Expected XFAM (8 bytes) + xfam: Optional[bytes] = None + # Expected MR_TD (48 bytes) + mr_td: Optional[bytes] = None + # Expected MR_CONFIG_ID (48 bytes) + mr_config_id: Optional[bytes] = None + # Expected MR_OWNER (48 bytes) + mr_owner: Optional[bytes] = None + # Expected MR_OWNER_CONFIG (48 bytes) + mr_owner_config: Optional[bytes] = None + # Expected RTMRs (list of 4 x 48 bytes) + rtmrs: Optional[tuple[Optional[bytes], ...]] = None + # Expected REPORT_DATA (64 bytes) + report_data: Optional[bytes] = None + # Any permitted MR_TD values (list of 48-byte values) + any_mr_td: Optional[tuple[Optional[bytes], ...]] = None + # Any permitted MR_SEAM values (list of 48-byte values) + any_mr_seam: Optional[tuple[Optional[bytes], ...]] = None + + +@dataclass(frozen=True) +class PolicyOptions: + """ + Complete validation options for TDX quote policy validation. + Mirrors Go's validate.Options struct. + """ + header: HeaderOptions = field(default_factory=HeaderOptions) + td_quote_body: TdQuoteBodyOptions = field(default_factory=TdQuoteBodyOptions) + + +# ============================================================================= +# Policy Validation Helper Functions +# ============================================================================= + +def _check_option_length(name: str, expected: int, value: Optional[bytes]) -> None: + """Check field length if value is provided.""" + if value is not None and len(value) != expected: + raise TdxValidationError( + f"Option '{name}' length is {len(value)}, expected {expected}" + ) + + +def _check_options_lengths(options: PolicyOptions) -> None: + """Validate all option field lengths.""" + h = options.header + t = options.td_quote_body + + _check_option_length("qe_vendor_id", QE_VENDOR_ID_SIZE, h.qe_vendor_id) + _check_option_length("minimum_tee_tcb_svn", TEE_TCB_SVN_SIZE, t.minimum_tee_tcb_svn) + _check_option_length("mr_seam", MR_SEAM_SIZE, t.mr_seam) + _check_option_length("td_attributes", TD_ATTRIBUTES_SIZE, t.td_attributes) + _check_option_length("xfam", XFAM_SIZE, t.xfam) + _check_option_length("mr_td", MR_TD_SIZE, t.mr_td) + _check_option_length("mr_config_id", MR_CONFIG_ID_SIZE, t.mr_config_id) + _check_option_length("mr_owner", MR_OWNER_SIZE, t.mr_owner) + _check_option_length("mr_owner_config", MR_OWNER_CONFIG_SIZE, t.mr_owner_config) + _check_option_length("report_data", REPORT_DATA_SIZE, t.report_data) + + if t.rtmrs is not None: + if len(t.rtmrs) != RTMR_COUNT: + raise TdxValidationError( + f"Option 'rtmrs' has {len(t.rtmrs)} entries, expected {RTMR_COUNT}" + ) + for i, rtmr in enumerate(t.rtmrs): + if rtmr is not None and len(rtmr) != RTMR_SIZE: + raise TdxValidationError( + f"Option 'rtmrs[{i}]' length is {len(rtmr)}, expected {RTMR_SIZE}" + ) + + if t.any_mr_td is not None: + if not any(v is not None for v in t.any_mr_td): + raise TdxValidationError("Option 'any_mr_td' contains no non-None entries") + for i, mr_td in enumerate(t.any_mr_td): + if mr_td is not None and len(mr_td) != MR_TD_SIZE: + raise TdxValidationError( + f"Option 'any_mr_td[{i}]' length is {len(mr_td)}, expected {MR_TD_SIZE}" + ) + + if t.any_mr_seam is not None: + if not any(v is not None for v in t.any_mr_seam): + raise TdxValidationError("Option 'any_mr_seam' contains no non-None entries") + for i, mr_seam in 
enumerate(t.any_mr_seam): + if mr_seam is not None and len(mr_seam) != MR_SEAM_SIZE: + raise TdxValidationError( + f"Option 'any_mr_seam[{i}]' length is {len(mr_seam)}, expected {MR_SEAM_SIZE}" + ) + + # Mutual exclusion: exact-match and allowlist cannot both be set + if t.mr_td is not None and t.any_mr_td is not None and len(t.any_mr_td) > 0: + raise TdxValidationError( + "Cannot set both 'mr_td' and 'any_mr_td' - use one or the other" + ) + if t.mr_seam is not None and t.any_mr_seam is not None and len(t.any_mr_seam) > 0: + raise TdxValidationError( + "Cannot set both 'mr_seam' and 'any_mr_seam' - use one or the other" + ) + + +def _byte_check( + field_name: str, + given: bytes, + expected: Optional[bytes], +) -> None: + """Check exact byte match if expected is provided.""" + if expected is None: + return # Skip check + + if given != expected: + raise TdxValidationError( + f"Quote field {field_name} is {given.hex()}, expected {expected.hex()}" + ) + + +def _is_svn_higher_or_equal(quote_svn: bytes, min_svn: Optional[bytes]) -> bool: + """Component-wise SVN comparison. Returns True if min_svn is None.""" + if min_svn is None: + return True + if len(quote_svn) != len(min_svn): + return False + for q, m in zip(quote_svn, min_svn): + if q < m: + return False + return True + + +# ============================================================================= +# Policy Validation Core Functions +# ============================================================================= + +def validate_xfam(xfam: bytes) -> None: + """ + Validate XFAM fixed bit constraints. + + Args: + xfam: 8-byte XFAM from quote + + Raises: + TdxValidationError: If fixed bit constraints violated + """ + if len(xfam) != XFAM_SIZE: + raise TdxValidationError(f"XFAM size is {len(xfam)}, expected {XFAM_SIZE}") + + value = struct.unpack(' None: + """ + Validate TD_ATTRIBUTES fixed bit constraints. + + Per Intel DCAP spec Section 2.3.2: "Verify that all TD Under Debug + flags (i.e., the TDATTRIBUTES.TUD field in the TD Quote Body) are + set to zero." + + Args: + td_attributes: 8-byte TD_ATTRIBUTES from quote + + Raises: + TdxValidationError: If validation fails + """ + if len(td_attributes) != TD_ATTRIBUTES_SIZE: + raise TdxValidationError( + f"TD_ATTRIBUTES size is {len(td_attributes)}, expected {TD_ATTRIBUTES_SIZE}" + ) + + value = struct.unpack(' None: + """Validate that a field is all zeros with the expected size.""" + if len(value) != expected_size: + raise TdxValidationError( + f"{name} size is {len(value)}, expected {expected_size}" + ) + if value != b'\x00' * expected_size: + raise TdxValidationError( + f"{name} must be zero for {context}, got {value.hex()}" + ) + + +def validate_seam_attributes(seam_attributes: bytes) -> None: + """ + Validate SEAMATTRIBUTES is zero (required for TDX 1.0/1.5). + + Per Intel TDX DCAP Quoting Library API section 2.3.2. + + Args: + seam_attributes: 8-byte SEAMATTRIBUTES from quote + + Raises: + TdxValidationError: If not zero + """ + _validate_zero_field("SEAMATTRIBUTES", seam_attributes, SEAM_ATTRIBUTES_SIZE, "TDX 1.0/1.5") + + +def validate_mr_signer_seam(mr_signer_seam: bytes) -> None: + """ + Validate MRSIGNERSEAM is zero (required for Intel TDX Module). + + Per Intel TDX DCAP Quoting Library API section 2.3.2. 
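+
+    A zero MRSIGNERSEAM indicates an Intel-signed TDX module; any
+    non-zero value is rejected.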
+ + Args: + mr_signer_seam: 48-byte MRSIGNERSEAM from quote + + Raises: + TdxValidationError: If not zero + """ + _validate_zero_field("MRSIGNERSEAM", mr_signer_seam, MR_SIGNER_SEAM_SIZE, "Intel TDX Module") + + +def _validate_exact_byte_matches(quote: QuoteV4, options: PolicyOptions) -> None: + """Validate exact byte matches for configured fields.""" + t = options.td_quote_body + h = options.header + body = quote.td_quote_body + + _byte_check("MR_SEAM", body.mr_seam, t.mr_seam) + _byte_check("TD_ATTRIBUTES", body.td_attributes, t.td_attributes) + _byte_check("XFAM", body.xfam, t.xfam) + _byte_check("MR_TD", body.mr_td, t.mr_td) + _byte_check("MR_CONFIG_ID", body.mr_config_id, t.mr_config_id) + _byte_check("MR_OWNER", body.mr_owner, t.mr_owner) + _byte_check("MR_OWNER_CONFIG", body.mr_owner_config, t.mr_owner_config) + _byte_check("REPORT_DATA", body.report_data, t.report_data) + _byte_check("QE_VENDOR_ID", quote.header.qe_vendor_id, h.qe_vendor_id) + + # RTMR checks + if t.rtmrs is not None: + for i, (given, expected) in enumerate(zip(body.rtmrs, t.rtmrs)): + if expected is not None: + _byte_check(f"RTMR[{i}]", given, expected) + + # Any MR_TD check (at least one must match) + if t.any_mr_td is not None and len(t.any_mr_td) > 0: + mr_td = body.mr_td + if not any(mr_td == allowed for allowed in t.any_mr_td if allowed is not None): + raise TdxValidationError( + f"MR_TD {mr_td.hex()} does not match any allowed value" + ) + + # Any MR_SEAM check (at least one must match) + if t.any_mr_seam is not None and len(t.any_mr_seam) > 0: + mr_seam = body.mr_seam + if not any(mr_seam == allowed for allowed in t.any_mr_seam if allowed is not None): + raise TdxValidationError( + f"MR_SEAM {mr_seam.hex()} does not match any allowed value" + ) + + +def _validate_min_versions(quote: QuoteV4, options: PolicyOptions) -> None: + """Validate minimum version requirements.""" + t = options.td_quote_body + + # TEE TCB SVN check + if t.minimum_tee_tcb_svn is not None: + if not _is_svn_higher_or_equal(quote.td_quote_body.tee_tcb_svn, t.minimum_tee_tcb_svn): + raise TdxValidationError( + f"TEE_TCB_SVN {quote.td_quote_body.tee_tcb_svn.hex()} is less than " + f"minimum {t.minimum_tee_tcb_svn.hex()}" + ) + + + +def validate_tdx_policy(quote: QuoteV4, options: PolicyOptions) -> None: + """ + Validate a TDX QuoteV4 against policy options. + + This is the main entry point for TDX quote policy validation. + It performs policy-based validation only - no cryptographic verification + or collateral fetching. + + Args: + quote: Parsed TDX QuoteV4 + options: Validation options + + Raises: + TdxValidationError: If any validation check fails + """ + # Validate option field lengths + _check_options_lengths(options) + + # Fixed bit validations (always run) + validate_xfam(quote.td_quote_body.xfam) + validate_td_attributes(quote.td_quote_body.td_attributes) + + # SEAM validations (required for TDX 1.0/1.5) + validate_seam_attributes(quote.td_quote_body.seam_attributes) + validate_mr_signer_seam(quote.td_quote_body.mr_signer_seam) + + # Exact byte match validations (optional based on options) + _validate_exact_byte_matches(quote, options) + + # Minimum version validations (optional based on options) + _validate_min_versions(quote, options) diff --git a/tests/test_validate_tdx.py b/tests/test_validate_tdx.py new file mode 100644 index 0000000..0c50429 --- /dev/null +++ b/tests/test_validate_tdx.py @@ -0,0 +1,421 @@ +""" +Tests for TDX attestation validation (validate_tdx.py). 
+
+Includes:
+- Unit tests for policy validation functions (validate_xfam, validate_td_attributes, etc.)
+- Integration tests for verify_tdx_attestation
+"""
+
+import base64
+import gzip
+import struct
+import pytest
+from unittest.mock import patch, MagicMock
+
+# Orchestration imports (from attestation_tdx)
+from tinfoil.attestation.attestation_tdx import (
+    verify_tdx_attestation,
+    TdxAttestationError,
+    EXPECTED_TD_ATTRIBUTES,
+    EXPECTED_XFAM,
+    EXPECTED_MINIMUM_TEE_TCB_SVN,
+)
+
+# Policy validation imports (from validate_tdx)
+from tinfoil.attestation.validate_tdx import (
+    TdxValidationError,
+    validate_xfam,
+    validate_td_attributes,
+    validate_seam_attributes,
+    validate_mr_signer_seam,
+    validate_tdx_policy,
+    PolicyOptions,
+    HeaderOptions,
+    TdQuoteBodyOptions,
+    XFAM_FIXED1,
+    XFAM_FIXED0,
+    TD_ATTRIBUTES_FIXED0,
+    TD_ATTRIBUTES_DEBUG_BIT,
+    TD_ATTRIBUTES_SEPT_VE_DIS,
+    TD_ATTRIBUTES_PKS,
+    TD_ATTRIBUTES_PERFMON,
+)
+
+# ABI imports
+from tinfoil.attestation.abi_tdx import INTEL_QE_VENDOR_ID
+from tinfoil.attestation.attestation import (
+    Document,
+    PredicateType,
+    Verification,
+    Measurement,
+)
+
+# Import test fixtures from test_tdx_verify
+import sys
+import os
+sys.path.insert(0, os.path.dirname(__file__))
+from test_tdx_verify import build_signed_quote_with_keys
+
+
+# =============================================================================
+# Unit Tests for Policy Validation Functions
+# =============================================================================
+
+class TestValidateXfam:
+    """Test XFAM fixed bit validation."""
+
+    def test_valid_xfam_with_fixed1_bits(self):
+        """XFAM with required FIXED1 bits set passes."""
+        xfam = struct.pack('<Q', XFAM_FIXED1)
+        # Should not raise
+        validate_xfam(xfam)
+
+
+# =============================================================================
+# Test Helpers
+# =============================================================================
+
+def create_test_attestation_doc(raw_quote: bytes, compress: bool = True) -> str:
+    """Create a base64-encoded attestation document."""
+    if compress:
+        compressed = gzip.compress(raw_quote)
+        return base64.b64encode(compressed).decode()
+    return base64.b64encode(raw_quote).decode()
+
+
+# =============================================================================
+# Unit Tests for verify_tdx_attestation
+# =============================================================================
+
+class TestVerifyTdxAttestation:
+    """Test verify_tdx_attestation function."""
+
+    def test_invalid_base64(self):
+        """Test that invalid base64 raises error."""
+        with pytest.raises(TdxAttestationError, match="Failed to decode base64"):
+            verify_tdx_attestation("not valid base64!!!")
+
+    def test_invalid_gzip(self):
+        """Test that invalid gzip raises error."""
+        # Valid base64 but not gzip
+        doc = base64.b64encode(b"not gzipped").decode()
+        with pytest.raises(TdxAttestationError, match="Failed to decompress"):
+            verify_tdx_attestation(doc, is_compressed=True)
+
+# =============================================================================
+# Integration Tests with attestation.py
+# =============================================================================
+
+class TestDocumentVerify:
+    """Test Document.verify() with TDX format."""
+
+    def test_document_verify_tdx(self):
+        """Test that Document.verify() handles TDX format."""
+        raw_quote, _, _ = build_signed_quote_with_keys()
+        doc = create_test_attestation_doc(raw_quote)
+
+        document = Document(
+            format=PredicateType.TDX_GUEST_V2,
+            body=doc,
+        )
+
+        with patch('tinfoil.attestation.attestation_tdx.verify_tdx_attestation') as mock_verify:
+            from tinfoil.attestation.attestation_tdx import TdxAttestationResult
+            mock_verify.return_value = TdxAttestationResult(
+                quote=MagicMock(),
+                pck_chain=MagicMock(),
+                pck_extensions=MagicMock(),
collateral=None, + tcb_level=None, + measurements=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + tls_key_fp="ff" * 32, + hpke_public_key="ee" * 32, + ) + + result = document.verify() + + assert isinstance(result, Verification) + assert result.measurement.type == PredicateType.TDX_GUEST_V2 + assert len(result.measurement.registers) == 5 + assert result.public_key_fp == "ff" * 32 + + +class TestMeasurementComparison: + """Test Measurement.assert_equal() with TDX measurements.""" + + def test_tdx_same_measurements_equal(self): + """Test that identical TDX measurements are equal.""" + m1 = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + m2 = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + + # Should not raise + m1.assert_equal(m2) + + def test_tdx_different_measurements_not_equal(self): + """Test that different TDX measurements raise error.""" + from tinfoil.attestation import MeasurementMismatchError + + m1 = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + m2 = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["ff" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + + with pytest.raises(MeasurementMismatchError): + m1.assert_equal(m2) + + def test_multiplatform_vs_tdx(self): + """Test multiplatform measurement comparison with TDX.""" + # Multiplatform: [SNP_measurement, RTMR1, RTMR2] + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=["snp" * 16, "cc" * 48, "dd" * 48], + ) + + # TDX: [MRTD, RTMR0, RTMR1, RTMR2, RTMR3] + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + + # RTMR1 and RTMR2 should match + # multiplatform.registers[1] == tdx.registers[2] (RTMR1) + # multiplatform.registers[2] == tdx.registers[3] (RTMR2) + multiplatform.assert_equal(tdx) # Should not raise + + +class TestMeasurementStr: + """Test Measurement.__str__() for TDX.""" + + def test_tdx_measurement_str(self): + """Test string representation of TDX measurement.""" + m = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + + s = str(m) + assert "TDX_GUEST_V2" in s or "tdx-guest" in s + assert "mrtd=" in s + assert "rtmr0=" in s + assert "rtmr1=" in s + assert "rtmr2=" in s + assert "rtmr3=" in s + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 8fc090006bb3f16e070f407f4ff38f4d347bb519 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:14:41 -0500 Subject: [PATCH 04/19] feat: add TDX collateral fetching and validation (collateral_tdx.py) Implement Intel PCS collateral fetching with disk caching for TCB Info, QE Identity, PCK CRL, and Root CA CRL. Validates TCB status, TDX module identity, QE identity, certificate revocation (including collateral signing certs per DCAP spec), and collateral freshness via tcbEvaluationDataNumber thresholds. 
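
Illustrative usage of the new API (sketch; assumes a parsed quote plus its
PCK certificate and extensions are already in hand):

    from tinfoil.attestation.collateral_tdx import (
        fetch_collateral, validate_tcb_status,
    )

    collateral = fetch_collateral(pck_extensions, pck_cert)
    level = validate_tcb_status(
        collateral.tcb_info.tcb_info,
        quote.td_quote_body.tee_tcb_svn,
        pck_extensions,
    )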
--- src/tinfoil/attestation/collateral_tdx.py | 2092 +++++++++++++++++++++ tests/test_collateral_tdx.py | 1892 +++++++++++++++++++ 2 files changed, 3984 insertions(+) create mode 100644 src/tinfoil/attestation/collateral_tdx.py create mode 100644 tests/test_collateral_tdx.py diff --git a/src/tinfoil/attestation/collateral_tdx.py b/src/tinfoil/attestation/collateral_tdx.py new file mode 100644 index 0000000..6a7bb50 --- /dev/null +++ b/src/tinfoil/attestation/collateral_tdx.py @@ -0,0 +1,2092 @@ +""" +Intel TDX Collateral Fetching and TCB Validation. + +This module handles fetching and validating collateral from Intel's +Provisioning Certification Service (PCS) for TDX attestation: + +- TCB Info: Contains TCB levels and status for the platform +- QE Identity: Contains Quoting Enclave identity information + +Intel PCS API: + Base URL: https://api.trustedservices.intel.com/tdx/certification/v4 + TCB Info: /tcb?fmspc={fmspc} + QE Identity: /qe/identity +""" + +import base64 +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +import json +import os +import stat +from typing import List, Optional, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from .abi_tdx import QuoteV4 + from .verify_tdx import PCKCertificateChain +from urllib.parse import unquote + +from cryptography import x509 +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature +import platformdirs +import requests + +from .intel_root_ca import get_intel_root_ca +from .pck_extensions import PckExtensions, extract_pck_extensions, PckExtensionError +from .cert_utils import ( + parse_pem_chain, + certs_to_pem, + verify_intel_chain, + CertificateChainError, +) + + +# Intel PCS API base URLs +INTEL_PCS_TDX_BASE_URL = "https://api.trustedservices.intel.com/tdx/certification/v4" +INTEL_PCS_SGX_BASE_URL = "https://api.trustedservices.intel.com/sgx/certification/v4" + +# Type alias for the HTTP session used by fetch functions. +# Pass a requests.Session to reuse connections, inject retries, or mock in tests. 
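+# A sketch of a hypothetical caller (not part of this module):
+#
+#     session = requests.Session()
+#     tcb_info, raw, chain = fetch_tcb_info("00806f050000", session=session)
+#     qe_identity, raw, chain = fetch_qe_identity(session=session)
+#
+# where "00806f050000" stands in for a real FMSPC.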
+HttpSession = Optional[requests.Session] + + +def _http_get(url: str, timeout: float, session: HttpSession = None) -> requests.Response: + """Issue a GET request via *session* (or a one-shot ``requests.get``).""" + try: + if session is not None: + resp = session.get(url, timeout=timeout) + else: + resp = requests.get(url, timeout=timeout) + resp.raise_for_status() + return resp + except requests.RequestException as e: + raise CollateralError(f"HTTP GET {url} failed: {e}") from e + +# Cache directory for TDX collateral +_TDX_CACHE_DIR = platformdirs.user_cache_dir("tinfoil", "tinfoil") + + +class CollateralError(Exception): + """Raised when collateral fetching or validation fails.""" + pass + + +class TcbStatus(str, Enum): + """TCB status values from Intel PCS.""" + UP_TO_DATE = "UpToDate" + SW_HARDENING_NEEDED = "SWHardeningNeeded" + CONFIGURATION_NEEDED = "ConfigurationNeeded" + CONFIGURATION_AND_SW_HARDENING_NEEDED = "ConfigurationAndSWHardeningNeeded" + OUT_OF_DATE = "OutOfDate" + OUT_OF_DATE_CONFIGURATION_NEEDED = "OutOfDateConfigurationNeeded" + REVOKED = "Revoked" + + +# ============================================================================= +# CRL Data Structures +# ============================================================================= + +@dataclass +class PckCrl: + """ + PCK Certificate Revocation List from Intel PCS. + + Contains the parsed CRL and metadata for caching. + """ + crl: x509.CertificateRevocationList + ca_type: str # "platform" or "processor" + next_update: datetime + + +@dataclass +class RootCrl: + """ + Intel SGX Root CA Certificate Revocation List. + + Used to check if intermediate CA certificates have been revoked. + """ + crl: x509.CertificateRevocationList + next_update: datetime + + +# ============================================================================= +# TCB Info Data Structures +# ============================================================================= + +@dataclass +class TcbComponent: + """A single TCB component with SVN and metadata.""" + svn: int + category: str = "" + type: str = "" + + +@dataclass +class Tcb: + """TCB level data containing SVN components.""" + sgx_tcb_components: List[TcbComponent] # 16 SGX components + pce_svn: int + tdx_tcb_components: List[TcbComponent] # 16 TDX components + isv_svn: Optional[int] = None # ISV SVN for QE Identity TCB levels + + +@dataclass +class TcbLevel: + """A TCB level with status and advisory IDs.""" + tcb: Tcb + tcb_date: str + tcb_status: TcbStatus + advisory_ids: List[str] + + +@dataclass +class TdxModule: + """TDX module identity information.""" + mrsigner: bytes # 48 bytes + attributes: bytes + attributes_mask: bytes + + +@dataclass +class TdxModuleIdentity: + """TDX module identity with associated TCB levels.""" + id: str # e.g., "TDX_01", "TDX_03" + mrsigner: bytes + attributes: bytes + attributes_mask: bytes + tcb_levels: List[TcbLevel] + + +@dataclass +class TcbInfo: + """TCB Info structure from Intel PCS.""" + id: str # Must be "TDX" + version: int # Must be 3 + issue_date: datetime + next_update: datetime + fmspc: str + pce_id: str + tcb_type: int + tcb_evaluation_data_number: int + tdx_module: Optional[TdxModule] + tdx_module_identities: List[TdxModuleIdentity] + tcb_levels: List[TcbLevel] + + +@dataclass +class TdxTcbInfo: + """Top-level TCB Info response with signature.""" + tcb_info: TcbInfo + signature: str + + +# ============================================================================= +# QE Identity Data Structures +# 
============================================================================= + +@dataclass +class EnclaveIdentity: + """Quoting Enclave identity information.""" + id: str # Must be "TD_QE" + version: int # Must be 2 + issue_date: datetime + next_update: datetime + tcb_evaluation_data_number: int + miscselect: bytes # 4 bytes + miscselect_mask: bytes # 4 bytes + attributes: bytes # 16 bytes + attributes_mask: bytes # 16 bytes + mrsigner: bytes # 32 bytes + isv_prod_id: int + tcb_levels: List[TcbLevel] + + +@dataclass +class QeIdentity: + """Top-level QE Identity response with signature.""" + enclave_identity: EnclaveIdentity + signature: str + + +# ============================================================================= +# Collateral Container +# ============================================================================= + +@dataclass +class TdxCollateral: + """Container for all TDX collateral data.""" + tcb_info: TdxTcbInfo + qe_identity: QeIdentity + tcb_info_raw: bytes # Raw JSON for signature verification + qe_identity_raw: bytes # Raw JSON for signature verification + pck_crl: Optional[PckCrl] = None # PCK CRL for revocation checking + root_crl: Optional[RootCrl] = None # Root CA CRL for intermediate revocation checking + tcb_info_issuer_chain: Optional[List[x509.Certificate]] = None # TCB Info signing chain for CRL checking + qe_identity_issuer_chain: Optional[List[x509.Certificate]] = None # QE Identity signing chain for CRL checking + + +# ============================================================================= +# Collateral Cache Helpers +# ============================================================================= + +# Cache directory permissions (owner-only) +_CACHE_DIR_MODE = stat.S_IRWXU # 0700 +_CACHE_FILE_MODE = stat.S_IRUSR | stat.S_IWUSR # 0600 + + +@dataclass +class CacheEntry: + """ + Cache entry containing body and issuer chain for signature verification. + + This allows re-verification of signatures on cache hits, not just on fetch. + """ + body: bytes # Raw response body (JSON for TCB/QE, DER for CRLs) + issuer_chain_pem: Optional[str] = None # PEM-encoded issuer cert chain + + +def _ensure_cache_dir() -> bool: + """ + Ensure cache directory exists with secure permissions (0700). + + Returns: + True if directory exists/was created, False on failure + """ + try: + os.makedirs(_TDX_CACHE_DIR, mode=_CACHE_DIR_MODE, exist_ok=True) + # Tighten permissions if directory already existed with looser perms + os.chmod(_TDX_CACHE_DIR, _CACHE_DIR_MODE) + return True + except OSError: + return False + + +def _get_tcb_info_cache_path(fmspc: str) -> str: + """Get cache file path for TCB Info (keyed by FMSPC).""" + return os.path.join(_TDX_CACHE_DIR, f"tdx_tcb_info_{fmspc.lower()}.json") + + +def _get_qe_identity_cache_path() -> str: + """Get cache file path for QE Identity (global, not FMSPC-specific).""" + return os.path.join(_TDX_CACHE_DIR, "tdx_qe_identity.json") + + +def _read_cache(cache_path: str) -> Optional[CacheEntry]: + """ + Read cached collateral from disk. 
+ + Returns: + CacheEntry with body and optional issuer chain, or None on failure + """ + if not os.path.isfile(cache_path): + return None + try: + with open(cache_path, "rb") as f: + data = json.loads(f.read().decode("utf-8")) + return CacheEntry( + body=base64.b64decode(data["body"]), + issuer_chain_pem=data.get("issuer_chain_pem"), + ) + except (OSError, json.JSONDecodeError, KeyError, ValueError): + return None + + +def _write_cache(cache_path: str, entry: CacheEntry) -> None: + """ + Write collateral to disk cache atomically with secure permissions. + + Uses write-to-temp + rename pattern to prevent partial writes. + Sets file permissions to 0600 (owner read/write only). + """ + if not _ensure_cache_dir(): + return + + data = { + "body": base64.b64encode(entry.body).decode("ascii"), + } + if entry.issuer_chain_pem is not None: + data["issuer_chain_pem"] = entry.issuer_chain_pem + + tmp_path = cache_path + ".tmp" + try: + # Write to temp file + fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, _CACHE_FILE_MODE) + try: + os.write(fd, json.dumps(data).encode("utf-8")) + finally: + os.close(fd) + # Atomic rename + os.replace(tmp_path, cache_path) + except OSError: + # Clean up temp file on failure + try: + os.unlink(tmp_path) + except OSError: + pass + + +def _is_fresh(next_update: Optional[datetime]) -> bool: + """Check if a collateral item is still fresh (now < next_update).""" + if next_update is None: + return False + return datetime.now(timezone.utc) < next_update + + +def _is_tcb_info_fresh(tcb_info: TdxTcbInfo) -> bool: + """Check if TCB Info is still fresh (not expired).""" + return _is_fresh(tcb_info.tcb_info.next_update) + + +def _is_qe_identity_fresh(qe_identity: QeIdentity) -> bool: + """Check if QE Identity is still fresh (not expired).""" + return _is_fresh(qe_identity.enclave_identity.next_update) + + +def _get_crl_cache_path(ca_type: str) -> str: + """Get cache file path for PCK CRL (keyed by CA type).""" + return os.path.join(_TDX_CACHE_DIR, f"tdx_pck_crl_{ca_type.lower()}.json") + + +def _get_root_crl_cache_path() -> str: + """Get cache file path for Intel SGX Root CA CRL.""" + return os.path.join(_TDX_CACHE_DIR, "intel_sgx_root_ca_crl.json") + + +def _is_crl_fresh(crl: x509.CertificateRevocationList) -> bool: + """Check if CRL is still fresh (not expired).""" + return _is_fresh(crl.next_update_utc) + + +# ============================================================================= +# Intel-defined field sizes (hex character counts) +# ============================================================================= + +TDX_MRSIGNER_SIZE = 48 # TDX module MRSIGNER: 48 bytes +QE_MRSIGNER_SIZE = 32 # QE enclave MRSIGNER: 32 bytes +QE_ATTRIBUTES_SIZE = 16 # QE enclave attributes: 16 bytes +MISCSELECT_SIZE = 4 # MISCSELECT field: 4 bytes +TCB_COMPONENT_COUNT = 16 # Number of TCB SVN components +ECDSA_P256_COMPONENT_SIZE = 32 # 32 bytes per R or S component +ECDSA_SIGNATURE_SIZE = 64 # R || S (32 + 32 bytes) + + +# ============================================================================= +# Parsing Functions +# ============================================================================= + +def _parse_hex_bytes(hex_str: str) -> bytes: + """Parse hex string to bytes.""" + return bytes.fromhex(hex_str) + + +def _parse_datetime(dt_str: str) -> datetime: + """Parse ISO datetime string to datetime object.""" + # Handle format: "2025-12-17T06:24:56Z" + return datetime.fromisoformat(dt_str.replace("Z", "+00:00")) + + +def _parse_tcb_component(data: dict) -> TcbComponent: 
+ """Parse a TCB component from JSON.""" + return TcbComponent( + svn=data.get("svn", 0), + category=data.get("category", ""), + type=data.get("type", ""), + ) + + +def _parse_tcb(data: dict) -> Tcb: + """Parse TCB from JSON.""" + sgx_components = [ + _parse_tcb_component(c) for c in data.get("sgxtcbcomponents", []) + ] + tdx_components = [ + _parse_tcb_component(c) for c in data.get("tdxtcbcomponents", []) + ] + + # Pad to 16 components if needed + while len(sgx_components) < TCB_COMPONENT_COUNT: + sgx_components.append(TcbComponent(svn=0)) + while len(tdx_components) < TCB_COMPONENT_COUNT: + tdx_components.append(TcbComponent(svn=0)) + + # Capture isvsvn if present (used by QE Identity TCB levels) + isv_svn = data.get("isvsvn") + + return Tcb( + sgx_tcb_components=sgx_components[:TCB_COMPONENT_COUNT], + pce_svn=data.get("pcesvn", 0), + tdx_tcb_components=tdx_components[:TCB_COMPONENT_COUNT], + isv_svn=isv_svn, + ) + + +def _parse_tcb_level(data: dict) -> TcbLevel: + """Parse a TCB level from JSON.""" + tcb_data = data.get("tcb", {}) + + # Handle simple TCB (just isvsvn) vs full TCB + if "sgxtcbcomponents" in tcb_data: + tcb = _parse_tcb(tcb_data) + else: + # Simple TCB with just isvsvn (used by QE Identity) + # Capture the isvsvn value for QE TCB level matching + isv_svn = tcb_data.get("isvsvn") + tcb = Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=isv_svn, + ) + + return TcbLevel( + tcb=tcb, + tcb_date=data.get("tcbDate", ""), + tcb_status=TcbStatus(data.get("tcbStatus", "UpToDate")), + advisory_ids=data.get("advisoryIDs", []), + ) + + +def _parse_tdx_module(data: dict) -> TdxModule: + """Parse TDX module from JSON.""" + return TdxModule( + mrsigner=_parse_hex_bytes(data.get("mrsigner", "00" * TDX_MRSIGNER_SIZE)), + attributes=_parse_hex_bytes(data.get("attributes", "")), + attributes_mask=_parse_hex_bytes(data.get("attributesMask", "")), + ) + + +def _parse_tdx_module_identity(data: dict) -> TdxModuleIdentity: + """Parse TDX module identity from JSON.""" + return TdxModuleIdentity( + id=data.get("id", ""), + mrsigner=_parse_hex_bytes(data.get("mrsigner", "00" * TDX_MRSIGNER_SIZE)), + attributes=_parse_hex_bytes(data.get("attributes", "")), + attributes_mask=_parse_hex_bytes(data.get("attributesMask", "")), + tcb_levels=[_parse_tcb_level(l) for l in data.get("tcbLevels", [])], + ) + + +def _parse_tcb_info(data: dict) -> TcbInfo: + """Parse TCB Info from JSON.""" + tdx_module = None + if "tdxModule" in data: + tdx_module = _parse_tdx_module(data["tdxModule"]) + + return TcbInfo( + id=data.get("id", ""), + version=data.get("version", 0), + issue_date=_parse_datetime(data.get("issueDate", "1970-01-01T00:00:00Z")), + next_update=_parse_datetime(data.get("nextUpdate", "1970-01-01T00:00:00Z")), + fmspc=data.get("fmspc", ""), + pce_id=data.get("pceId", ""), + tcb_type=data.get("tcbType", 0), + tcb_evaluation_data_number=data.get("tcbEvaluationDataNumber", 0), + tdx_module=tdx_module, + tdx_module_identities=[ + _parse_tdx_module_identity(m) for m in data.get("tdxModuleIdentities", []) + ], + tcb_levels=[_parse_tcb_level(l) for l in data.get("tcbLevels", [])], + ) + + +def _parse_enclave_identity(data: dict) -> EnclaveIdentity: + """Parse Enclave Identity from JSON.""" + return EnclaveIdentity( + id=data.get("id", ""), + version=data.get("version", 0), + issue_date=_parse_datetime(data.get("issueDate", "1970-01-01T00:00:00Z")), + next_update=_parse_datetime(data.get("nextUpdate", "1970-01-01T00:00:00Z")), + 
tcb_evaluation_data_number=data.get("tcbEvaluationDataNumber", 0), + miscselect=_parse_hex_bytes(data.get("miscselect", "00" * MISCSELECT_SIZE)), + miscselect_mask=_parse_hex_bytes(data.get("miscselectMask", "00" * MISCSELECT_SIZE)), + attributes=_parse_hex_bytes(data.get("attributes", "00" * QE_ATTRIBUTES_SIZE)), + attributes_mask=_parse_hex_bytes(data.get("attributesMask", "00" * QE_ATTRIBUTES_SIZE)), + mrsigner=_parse_hex_bytes(data.get("mrsigner", "00" * QE_MRSIGNER_SIZE)), + isv_prod_id=data.get("isvprodid", 0), + tcb_levels=[_parse_tcb_level(l) for l in data.get("tcbLevels", [])], + ) + + +def parse_tcb_info_response(response_bytes: bytes) -> TdxTcbInfo: + """ + Parse TCB Info response from Intel PCS. + + Args: + response_bytes: Raw JSON response bytes + + Returns: + Parsed TdxTcbInfo + + Raises: + CollateralError: If parsing fails + """ + try: + data = json.loads(response_bytes) + except json.JSONDecodeError as e: + raise CollateralError(f"Failed to parse TCB Info JSON: {e}") + + tcb_info = _parse_tcb_info(data.get("tcbInfo", {})) + + # Validate required fields + if tcb_info.id != "TDX": + raise CollateralError(f"TCB Info ID must be 'TDX', got '{tcb_info.id}'") + if tcb_info.version != 3: + raise CollateralError(f"TCB Info version must be 3, got {tcb_info.version}") + + return TdxTcbInfo( + tcb_info=tcb_info, + signature=data.get("signature", ""), + ) + + +def parse_qe_identity_response(response_bytes: bytes) -> QeIdentity: + """ + Parse QE Identity response from Intel PCS. + + Args: + response_bytes: Raw JSON response bytes + + Returns: + Parsed QeIdentity + + Raises: + CollateralError: If parsing fails + """ + try: + data = json.loads(response_bytes) + except json.JSONDecodeError as e: + raise CollateralError(f"Failed to parse QE Identity JSON: {e}") + + enclave_identity = _parse_enclave_identity(data.get("enclaveIdentity", {})) + + # Validate required fields + if enclave_identity.id != "TD_QE": + raise CollateralError( + f"QE Identity ID must be 'TD_QE', got '{enclave_identity.id}'" + ) + if enclave_identity.version != 2: + raise CollateralError( + f"QE Identity version must be 2, got {enclave_identity.version}" + ) + + return QeIdentity( + enclave_identity=enclave_identity, + signature=data.get("signature", ""), + ) + + +# ============================================================================= +# Collateral Signature Verification +# ============================================================================= + +def _parse_issuer_chain_header(header_value: str) -> List[x509.Certificate]: + """ + Parse the issuer certificate chain from a PCS response header. + + Intel PCS returns the chain as URL-encoded concatenated PEM certificates + in the TCB-Info-Issuer-Chain or SGX-Enclave-Identity-Issuer-Chain header. 
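+
+    A typical value looks like (abbreviated):
+    "-----BEGIN%20CERTIFICATE-----%0AMII...%0A-----END%20CERTIFICATE-----%0A...",
+    i.e. URL-encoded PEM that decodes to at least two concatenated blocks.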
+ + Args: + header_value: URL-encoded PEM certificate chain + + Returns: + List of parsed certificates (signing cert first, root last) + + Raises: + CollateralError: If parsing fails + """ + # URL-decode the header value + pem_data = unquote(header_value).encode('utf-8') + + try: + certs = parse_pem_chain(pem_data) + except CertificateChainError as e: + raise CollateralError(f"Failed to parse issuer chain certificate: {e}") + + if len(certs) < 2: + raise CollateralError( + f"Issuer chain should contain at least 2 certificates, got {len(certs)}" + ) + + return certs + + +def _verify_collateral_signature( + json_bytes: bytes, + json_key: str, + signature_hex: str, + signing_cert: x509.Certificate, + data_name: str, +) -> None: + """ + Verify the signature over collateral JSON data. + + Intel PCS signs the raw JSON string of the inner object (tcbInfo or + enclaveIdentity), not the outer wrapper. The signature is ECDSA P-256 + over SHA256 of the raw JSON bytes. + + Args: + json_bytes: Raw response bytes (full JSON) + json_key: Key name to extract for signing ("tcbInfo" or "enclaveIdentity") + signature_hex: Hex-encoded signature from response + signing_cert: Signing certificate (leaf of issuer chain) + data_name: Human-readable name for error messages + + Raises: + CollateralError: If signature verification fails + """ + # Extract the raw JSON string for the signed object + # We need to find the exact bytes of the inner JSON object as it appears + # in the response (including whitespace), since the signature is over + # the exact byte representation. + # + # The format is: {"tcbInfo":{...},"signature":"..."} + # We need to extract the exact "{...}" for tcbInfo + try: + json_str = json_bytes.decode('utf-8') + except UnicodeDecodeError as e: + raise CollateralError(f"Failed to decode {data_name} JSON: {e}") + + # Find the start of the inner object. + # Only match at the top-level (must be preceded by '{' or ',' ignoring whitespace). + key_pattern = f'"{json_key}":' + key_pos = json_str.find(key_pattern) + if key_pos == -1: + raise CollateralError(f"{data_name} JSON does not contain '{json_key}' key") + + # Verify the match is at the top level (preceded by '{' or ',' outside strings) + prefix = json_str[:key_pos].rstrip() + if not prefix or prefix[-1] not in ('{', ','): + raise CollateralError( + f"{data_name} JSON key '{json_key}' found in unexpected position" + ) + + # Find the start of the object value (skip whitespace after colon) + obj_start = key_pos + len(key_pattern) + while obj_start < len(json_str) and json_str[obj_start] in ' \t\n\r': + obj_start += 1 + + if obj_start >= len(json_str) or json_str[obj_start] != '{': + raise CollateralError(f"{data_name} '{json_key}' is not an object") + + # Use the stdlib JSON decoder to find the extent of the nested object. + # raw_decode parses from the given index and returns the end position, + # correctly handling string escaping, nesting, and all JSON edge cases. 
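+    # For example, json.JSONDecoder().raw_decode('{"tcbInfo":{"a":1},"signature":""}', 11)
+    # returns ({'a': 1}, 18): the parsed object and the index just past its closing '}'.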
+ try: + decoder = json.JSONDecoder() + _, obj_end = decoder.raw_decode(json_str, obj_start) + except json.JSONDecodeError as e: + raise CollateralError(f"{data_name} JSON has malformed '{json_key}' object: {e}") from e + + # Extract the signed data (exact bytes as they appear in the response) + signed_json = json_str[obj_start:obj_end].encode('utf-8') + + # Parse signature (hex-encoded raw R||S format, 64 bytes = 128 hex chars) + try: + sig_bytes = bytes.fromhex(signature_hex) + except ValueError as e: + raise CollateralError(f"{data_name} signature is not valid hex: {e}") + + if len(sig_bytes) != ECDSA_SIGNATURE_SIZE: + raise CollateralError( + f"{data_name} signature is {len(sig_bytes)} bytes, expected {ECDSA_SIGNATURE_SIZE}" + ) + + # Convert R||S to DER format + r = int.from_bytes(sig_bytes[0:ECDSA_P256_COMPONENT_SIZE], byteorder='big') + s = int.from_bytes(sig_bytes[ECDSA_P256_COMPONENT_SIZE:ECDSA_SIGNATURE_SIZE], byteorder='big') + signature_der = encode_dss_signature(r, s) + + # Verify ECDSA signature + try: + public_key = signing_cert.public_key() + public_key.verify(signature_der, signed_json, ec.ECDSA(hashes.SHA256())) + except InvalidSignature: + raise CollateralError( + f"{data_name} signature verification failed: signature does not match content" + ) + + +def verify_tcb_info_signature( + response_bytes: bytes, + tcb_info: TdxTcbInfo, + issuer_chain: List[x509.Certificate], +) -> None: + """ + Verify TCB Info signature against the issuer certificate chain. + + Args: + response_bytes: Raw TCB Info response bytes + tcb_info: Parsed TCB Info + issuer_chain: Issuer certificate chain from response header + + Raises: + CollateralError: If verification fails + """ + # Verify the issuer chain + try: + verify_intel_chain(issuer_chain, "TCB Info issuer chain") + except CertificateChainError as e: + raise CollateralError(str(e)) + + # Verify signature over tcbInfo JSON + _verify_collateral_signature( + json_bytes=response_bytes, + json_key="tcbInfo", + signature_hex=tcb_info.signature, + signing_cert=issuer_chain[0], + data_name="TCB Info", + ) + + +def verify_qe_identity_signature( + response_bytes: bytes, + qe_identity: QeIdentity, + issuer_chain: List[x509.Certificate], +) -> None: + """ + Verify QE Identity signature against the issuer certificate chain. + + Args: + response_bytes: Raw QE Identity response bytes + qe_identity: Parsed QE Identity + issuer_chain: Issuer certificate chain from response header + + Raises: + CollateralError: If verification fails + """ + # Verify the issuer chain + try: + verify_intel_chain(issuer_chain, "QE Identity issuer chain") + except CertificateChainError as e: + raise CollateralError(str(e)) + + # Verify signature over enclaveIdentity JSON + _verify_collateral_signature( + json_bytes=response_bytes, + json_key="enclaveIdentity", + signature_hex=qe_identity.signature, + signing_cert=issuer_chain[0], + data_name="QE Identity", + ) + + +# ============================================================================= +# Collateral Fetching +# ============================================================================= + +def fetch_tcb_info( + fmspc: str, timeout: float = 30.0, session: HttpSession = None, +) -> Tuple[TdxTcbInfo, bytes, List[x509.Certificate]]: + """ + Fetch TCB Info from Intel PCS, with caching. + + Cached TCB Info is stored on disk and reused until the next_update + timestamp expires. Signature is re-verified on every cache hit. 
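+
+    Example (illustrative FMSPC value):
+        >>> tcb_info, raw, chain = fetch_tcb_info("00806f050000")
+        >>> tcb_info.tcb_info.id
+        'TDX'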
+ + Args: + fmspc: FMSPC value from PCK certificate (6 bytes hex) + timeout: Request timeout in seconds + session: Optional requests.Session for connection reuse / testing + + Returns: + Tuple of (parsed TdxTcbInfo, raw response bytes, issuer certificate chain) + + Raises: + CollateralError: If fetching, parsing, or signature verification fails + """ + cache_path = _get_tcb_info_cache_path(fmspc) + + # Try cache first + cached_entry = _read_cache(cache_path) + if cached_entry is not None and cached_entry.issuer_chain_pem is not None: + try: + cached_tcb_info = parse_tcb_info_response(cached_entry.body) + if _is_tcb_info_fresh(cached_tcb_info): + # Re-verify signature on cache hit + issuer_chain = parse_pem_chain(cached_entry.issuer_chain_pem.encode("utf-8")) + verify_tcb_info_signature( + cached_entry.body, cached_tcb_info, issuer_chain + ) + return cached_tcb_info, cached_entry.body, issuer_chain + except (CollateralError, CertificateChainError, ValueError, KeyError, OSError): + pass + + # Cache miss or stale - fetch from Intel PCS + url = f"{INTEL_PCS_TDX_BASE_URL}/tcb?fmspc={fmspc}" + response = _http_get(url, timeout, session) + + raw_bytes = response.content + tcb_info = parse_tcb_info_response(raw_bytes) + + # Extract and verify issuer certificate chain from response header + issuer_chain_header = response.headers.get("TCB-Info-Issuer-Chain") + if not issuer_chain_header: + raise CollateralError( + "TCB Info response missing TCB-Info-Issuer-Chain header" + ) + + issuer_chain = _parse_issuer_chain_header(issuer_chain_header) + verify_tcb_info_signature(raw_bytes, tcb_info, issuer_chain) + + # Write to cache with issuer chain for re-verification + cache_entry = CacheEntry( + body=raw_bytes, + issuer_chain_pem=certs_to_pem(issuer_chain), + ) + _write_cache(cache_path, cache_entry) + + return tcb_info, raw_bytes, issuer_chain + + +def fetch_qe_identity( + timeout: float = 30.0, session: HttpSession = None, +) -> Tuple[QeIdentity, bytes, List[x509.Certificate]]: + """ + Fetch QE Identity from Intel PCS, with caching. + + Cached QE Identity is stored on disk and reused until the next_update + timestamp expires. Signature is re-verified on every cache hit. + + Note: QE Identity is global (not FMSPC-specific) so there's only one + cache file shared across all platforms. 
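+
+    Example (illustrative):
+        >>> qe_identity, raw, chain = fetch_qe_identity()
+        >>> qe_identity.enclave_identity.id
+        'TD_QE'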
+ + Args: + timeout: Request timeout in seconds + session: Optional requests.Session for connection reuse / testing + + Returns: + Tuple of (parsed QeIdentity, raw response bytes, issuer certificate chain) + + Raises: + CollateralError: If fetching, parsing, or signature verification fails + """ + cache_path = _get_qe_identity_cache_path() + + # Try cache first + cached_entry = _read_cache(cache_path) + if cached_entry is not None and cached_entry.issuer_chain_pem is not None: + try: + cached_qe_identity = parse_qe_identity_response(cached_entry.body) + if _is_qe_identity_fresh(cached_qe_identity): + # Re-verify signature on cache hit + issuer_chain = parse_pem_chain(cached_entry.issuer_chain_pem.encode("utf-8")) + verify_qe_identity_signature( + cached_entry.body, cached_qe_identity, issuer_chain + ) + return cached_qe_identity, cached_entry.body, issuer_chain + except (CollateralError, CertificateChainError, ValueError, KeyError, OSError): + pass + + # Cache miss or stale - fetch from Intel PCS + url = f"{INTEL_PCS_TDX_BASE_URL}/qe/identity" + response = _http_get(url, timeout, session) + + raw_bytes = response.content + qe_identity = parse_qe_identity_response(raw_bytes) + + # Extract and verify issuer certificate chain from response header + issuer_chain_header = response.headers.get("SGX-Enclave-Identity-Issuer-Chain") + if not issuer_chain_header: + raise CollateralError( + "QE Identity response missing SGX-Enclave-Identity-Issuer-Chain header" + ) + + issuer_chain = _parse_issuer_chain_header(issuer_chain_header) + verify_qe_identity_signature(raw_bytes, qe_identity, issuer_chain) + + # Write to cache with issuer chain for re-verification + cache_entry = CacheEntry( + body=raw_bytes, + issuer_chain_pem=certs_to_pem(issuer_chain), + ) + _write_cache(cache_path, cache_entry) + + return qe_identity, raw_bytes, issuer_chain + + +def _verify_crl_signature( + crl: x509.CertificateRevocationList, + issuer_chain: List[x509.Certificate], + ca_type: str, +) -> None: + """ + Verify the CRL signature against the issuer certificate chain. + + Args: + crl: Parsed CRL + issuer_chain: Issuer certificate chain from response header + ca_type: CA type for error messages + + Raises: + CollateralError: If verification fails + """ + # Verify the issuer chain first + try: + verify_intel_chain(issuer_chain, f"PCK CRL ({ca_type}) issuer chain") + except CertificateChainError as e: + raise CollateralError(str(e)) + + # Verify CRL signature using the signing cert (first in chain) + signing_cert = issuer_chain[0] + hash_algo = crl.signature_hash_algorithm + if hash_algo is None: + raise CollateralError(f"PCK CRL ({ca_type}) has no signature hash algorithm") + try: + signing_cert.public_key().verify( + crl.signature, + crl.tbs_certlist_bytes, + ec.ECDSA(hash_algo), + ) + except InvalidSignature: + raise CollateralError( + f"PCK CRL ({ca_type}) signature verification failed" + ) + + +def fetch_pck_crl( + ca_type: str, timeout: float = 30.0, session: HttpSession = None, +) -> PckCrl: + """ + Fetch PCK CRL from Intel PCS, with caching. + + The CRL is fetched from the SGX certification API (shared with TDX). + Cached CRL is stored on disk and reused until next_update expires. + Signature is re-verified on every cache hit. 
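+
+    Example (illustrative):
+        >>> pck_crl = fetch_pck_crl("platform")
+        >>> pck_crl.ca_type
+        'platform'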
+ + Args: + ca_type: CA type - "platform" or "processor" + timeout: Request timeout in seconds + + Returns: + Parsed PckCrl + + Raises: + CollateralError: If fetching, parsing, or signature verification fails + """ + if ca_type not in ("platform", "processor"): + raise CollateralError(f"Invalid CA type: {ca_type}. Must be 'platform' or 'processor'") + + cache_path = _get_crl_cache_path(ca_type) + + # Try cache first + cached_entry = _read_cache(cache_path) + if cached_entry is not None and cached_entry.issuer_chain_pem is not None: + try: + cached_crl = x509.load_pem_x509_crl(cached_entry.body) + if _is_crl_fresh(cached_crl): + # Re-verify signature on cache hit + issuer_chain = parse_pem_chain(cached_entry.issuer_chain_pem.encode("utf-8")) + _verify_crl_signature(cached_crl, issuer_chain, ca_type) + next_update = cached_crl.next_update_utc + if next_update is None: + raise CollateralError(f"PCK CRL ({ca_type}) is missing next_update field") + return PckCrl(crl=cached_crl, ca_type=ca_type, next_update=next_update) + except (CollateralError, CertificateChainError, ValueError, KeyError, OSError): + pass + + # Cache miss or stale - fetch from Intel PCS + # CRL endpoint is under the SGX API (not TDX-specific) + url = f"{INTEL_PCS_SGX_BASE_URL}/pckcrl?ca={ca_type}" + response = _http_get(url, timeout, session) + + raw_bytes = response.content + + # Parse the PEM-encoded CRL + try: + crl = x509.load_pem_x509_crl(raw_bytes) + except ValueError as e: + raise CollateralError(f"Failed to parse PCK CRL ({ca_type}): {e}") from e + + # Extract and verify issuer certificate chain from response header + issuer_chain_header = response.headers.get("SGX-PCK-CRL-Issuer-Chain") + if not issuer_chain_header: + raise CollateralError( + f"PCK CRL ({ca_type}) response missing SGX-PCK-CRL-Issuer-Chain header" + ) + + issuer_chain = _parse_issuer_chain_header(issuer_chain_header) + _verify_crl_signature(crl, issuer_chain, ca_type) + + # Write to cache with issuer chain for re-verification + cache_entry = CacheEntry( + body=raw_bytes, + issuer_chain_pem=certs_to_pem(issuer_chain), + ) + _write_cache(cache_path, cache_entry) + + next_update = crl.next_update_utc + if next_update is None: + raise CollateralError(f"PCK CRL ({ca_type}) is missing next_update field") + + return PckCrl(crl=crl, ca_type=ca_type, next_update=next_update) + + +# Intel SGX Root CA CRL URL (from the certificate's CRL Distribution Point) +INTEL_SGX_ROOT_CA_CRL_URL = "https://certificates.trustedservices.intel.com/IntelSGXRootCA.der" + + +def _verify_root_crl_signature(crl: x509.CertificateRevocationList) -> None: + """ + Verify the Root CA CRL signature against Intel SGX Root CA. + + Args: + crl: Parsed CRL + + Raises: + CollateralError: If signature verification fails + """ + intel_root = get_intel_root_ca() + hash_algo = crl.signature_hash_algorithm + if hash_algo is None: + raise CollateralError("Intel SGX Root CA CRL has no signature hash algorithm") + try: + intel_root.public_key().verify( + crl.signature, + crl.tbs_certlist_bytes, + ec.ECDSA(hash_algo), + ) + except InvalidSignature: + raise CollateralError("Intel SGX Root CA CRL signature verification failed") + + +def fetch_root_ca_crl(timeout: float = 30.0, session: HttpSession = None) -> RootCrl: + """ + Fetch Intel SGX Root CA CRL, with caching. + + This CRL lists revoked intermediate CA certificates (Platform CA, Processor CA). + Used to verify that the PCK certificate chain's intermediate CA is not revoked. + Signature is re-verified on every cache hit. 
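+
+    Example (illustrative):
+        >>> root_crl = fetch_root_ca_crl()
+        >>> isinstance(root_crl.next_update, datetime)
+        True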
+ + Args: + timeout: Request timeout in seconds + + Returns: + Parsed RootCrl + + Raises: + CollateralError: If fetching, parsing, or signature verification fails + """ + cache_path = _get_root_crl_cache_path() + + # Try cache first + cached_entry = _read_cache(cache_path) + if cached_entry is not None: + try: + cached_crl = x509.load_der_x509_crl(cached_entry.body) + if _is_crl_fresh(cached_crl): + # Re-verify signature on cache hit + _verify_root_crl_signature(cached_crl) + next_update = cached_crl.next_update_utc + if next_update is None: + raise CollateralError("Intel SGX Root CA CRL is missing next_update field") + return RootCrl(crl=cached_crl, next_update=next_update) + except (CollateralError, CertificateChainError, ValueError, KeyError, OSError): + pass + + # Cache miss or stale - fetch from Intel + response = _http_get(INTEL_SGX_ROOT_CA_CRL_URL, timeout, session) + + raw_bytes = response.content + + # Parse the DER-encoded CRL + try: + crl = x509.load_der_x509_crl(raw_bytes) + except Exception as e: + raise CollateralError(f"Failed to parse Intel SGX Root CA CRL: {e}") + + # Verify CRL is signed by the Intel SGX Root CA + _verify_root_crl_signature(crl) + + # Write to cache (no issuer chain needed - verified against embedded root) + cache_entry = CacheEntry(body=raw_bytes) + _write_cache(cache_path, cache_entry) + + next_update = crl.next_update_utc + if next_update is None: + raise CollateralError("Intel SGX Root CA CRL is missing next_update field") + + return RootCrl(crl=crl, next_update=next_update) + + +def _determine_pck_ca_type(pck_cert: x509.Certificate) -> str: + """ + Determine which CA issued the PCK certificate. + + The PCK certificate issuer CN indicates the CA type: + - "Intel SGX PCK Platform CA" -> "platform" + - "Intel SGX PCK Processor CA" -> "processor" + + Args: + pck_cert: PCK certificate from the quote + + Returns: + CA type string ("platform" or "processor") + + Raises: + CollateralError: If CA type cannot be determined + """ + try: + issuer = pck_cert.issuer + for attr in issuer: + if attr.oid == x509.oid.NameOID.COMMON_NAME: + cn = attr.value + if "Platform" in cn: + return "platform" + elif "Processor" in cn: + return "processor" + raise CollateralError( + f"Could not determine PCK CA type from issuer: {issuer}" + ) + except Exception as e: + raise CollateralError(f"Failed to determine PCK CA type: {e}") + + +def fetch_collateral( + pck_extensions: PckExtensions, + pck_cert: x509.Certificate, + timeout: float = 30.0, + session: HttpSession = None, +) -> TdxCollateral: + """ + Fetch all required collateral from Intel PCS. 
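+
+    Example (illustrative; the extensions and certificate come from a
+    verified PCK chain):
+        >>> collateral = fetch_collateral(pck_extensions, pck_cert)
+        >>> collateral.pck_crl.ca_type in ("platform", "processor")
+        True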
+ + Args: + pck_extensions: PCK certificate extensions containing FMSPC + pck_cert: PCK certificate (needed for CRL fetching) + timeout: Request timeout in seconds + session: Optional requests.Session for connection reuse / testing + + Returns: + TdxCollateral containing all fetched data + + Raises: + CollateralError: If fetching fails + """ + tcb_info, tcb_info_raw, tcb_info_issuer_chain = fetch_tcb_info(pck_extensions.fmspc, timeout, session) + qe_identity, qe_identity_raw, qe_identity_issuer_chain = fetch_qe_identity(timeout, session) + + # Fetch CRLs for revocation checking + ca_type = _determine_pck_ca_type(pck_cert) + pck_crl = fetch_pck_crl(ca_type, timeout, session) + root_crl = fetch_root_ca_crl(timeout, session) + + return TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=tcb_info_raw, + qe_identity_raw=qe_identity_raw, + pck_crl=pck_crl, + root_crl=root_crl, + tcb_info_issuer_chain=tcb_info_issuer_chain, + qe_identity_issuer_chain=qe_identity_issuer_chain, + ) + + +# ============================================================================= +# TCB Evaluation Data Number Calculation +# ============================================================================= + +@dataclass +class TcbEvalNumber: + """A single TCB evaluation data number with its dates.""" + tcb_evaluation_data_number: int + tcb_recovery_event_date: datetime + tcb_date: datetime + + +@dataclass +class TcbEvaluationDataNumbers: + """Response from the tcbevaluationdatanumbers endpoint.""" + id: str # "TDX" + version: int + issue_date: datetime + next_update: datetime + tcb_eval_numbers: List[TcbEvalNumber] + signature: str + + +def _parse_tcb_eval_numbers_response(response_bytes: bytes) -> TcbEvaluationDataNumbers: + """ + Parse the tcbevaluationdatanumbers response from Intel PCS. + + Args: + response_bytes: Raw JSON response bytes + + Returns: + Parsed TcbEvaluationDataNumbers + + Raises: + CollateralError: If parsing fails + """ + try: + data = json.loads(response_bytes) + except json.JSONDecodeError as e: + raise CollateralError(f"Failed to parse TCB evaluation data numbers JSON: {e}") + + inner = data.get("tcbEvaluationDataNumbers", {}) + + tcb_eval_numbers = [] + for item in inner.get("tcbEvalNumbers", []): + tcb_eval_numbers.append(TcbEvalNumber( + tcb_evaluation_data_number=item.get("tcbEvaluationDataNumber", 0), + tcb_recovery_event_date=_parse_datetime(item.get("tcbRecoveryEventDate", "1970-01-01T00:00:00Z")), + tcb_date=_parse_datetime(item.get("tcbDate", "1970-01-01T00:00:00Z")), + )) + + return TcbEvaluationDataNumbers( + id=inner.get("id", ""), + version=inner.get("version", 0), + issue_date=_parse_datetime(inner.get("issueDate", "1970-01-01T00:00:00Z")), + next_update=_parse_datetime(inner.get("nextUpdate", "1970-01-01T00:00:00Z")), + tcb_eval_numbers=tcb_eval_numbers, + signature=data.get("signature", ""), + ) + + +def verify_tcb_eval_numbers_signature( + response_bytes: bytes, + data: TcbEvaluationDataNumbers, + issuer_chain: List[x509.Certificate], +) -> None: + """ + Verify TCB evaluation data numbers signature against the issuer certificate chain. 
+
+    Args:
+        response_bytes: Raw response bytes
+        data: Parsed TcbEvaluationDataNumbers
+        issuer_chain: Issuer certificate chain from response header
+
+    Raises:
+        CollateralError: If verification fails
+    """
+    try:
+        verify_intel_chain(issuer_chain, "TCB eval numbers issuer chain")
+    except CertificateChainError as e:
+        raise CollateralError(str(e))
+
+    _verify_collateral_signature(
+        json_bytes=response_bytes,
+        json_key="tcbEvaluationDataNumbers",
+        signature_hex=data.signature,
+        signing_cert=issuer_chain[0],
+        data_name="TCB evaluation data numbers",
+    )
+
+
+def fetch_tcb_evaluation_data_numbers(
+    timeout: float = 30.0, session: HttpSession = None,
+) -> TcbEvaluationDataNumbers:
+    """
+    Fetch TCB evaluation data numbers from Intel PCS.
+
+    This endpoint returns all TCB evaluation data numbers with their
+    TCB recovery event dates, which can be used to determine the minimum
+    acceptable tcbEvaluationDataNumber based on age.
+
+    Args:
+        timeout: Request timeout in seconds
+        session: Optional requests.Session for connection reuse / testing
+
+    Returns:
+        Parsed TcbEvaluationDataNumbers
+
+    Raises:
+        CollateralError: If fetching, parsing, or signature verification fails
+    """
+    url = f"{INTEL_PCS_TDX_BASE_URL}/tcbevaluationdatanumbers"
+    response = _http_get(url, timeout, session)
+
+    raw_bytes = response.content
+    data = _parse_tcb_eval_numbers_response(raw_bytes)
+
+    issuer_chain_header = response.headers.get("TCB-Evaluation-Data-Numbers-Issuer-Chain")
+    if not issuer_chain_header:
+        raise CollateralError(
+            "TCB evaluation data numbers response missing "
+            "TCB-Evaluation-Data-Numbers-Issuer-Chain header"
+        )
+
+    issuer_chain = _parse_issuer_chain_header(issuer_chain_header)
+    verify_tcb_eval_numbers_signature(raw_bytes, data, issuer_chain)
+
+    return data
+
+
+def calculate_min_tcb_evaluation_data_number(
+    max_age_days: int = 365,
+    timeout: float = 30.0,
+    session: HttpSession = None,
+) -> int:
+    """
+    Calculate the minimum acceptable tcbEvaluationDataNumber based on age.
+
+    This function queries Intel PCS to find the lowest tcbEvaluationDataNumber
+    whose TCB recovery event date is within the specified maximum age. This
+    ensures that collateral from older TCB recovery events is rejected.
+
+    The response signature is verified against an issuer chain rooted in the
+    Intel SGX Root CA; a response missing the issuer chain header is rejected.
+
+    Args:
+        max_age_days: Maximum age in days (default: 365 = 1 year)
+        timeout: Request timeout in seconds
+        session: Optional requests.Session for connection reuse / testing
+
+    Returns:
+        Minimum acceptable tcbEvaluationDataNumber
+
+    Raises:
+        CollateralError: If calculation fails or no valid number found
+
+    Example:
+        >>> min_num = calculate_min_tcb_evaluation_data_number()
+        >>> print(f"Minimum acceptable tcbEvaluationDataNumber: {min_num}")
+    """
+    cutoff_date = datetime.now(timezone.utc) - timedelta(days=max_age_days)
+
+    data = fetch_tcb_evaluation_data_numbers(timeout, session=session)
+
+    if not data.tcb_eval_numbers:
+        raise CollateralError("No TCB evaluation data numbers found in Intel PCS response")
+
+    # Sort by number ascending to find the lowest acceptable
+    sorted_numbers = sorted(data.tcb_eval_numbers, key=lambda x: x.tcb_evaluation_data_number)
+
+    # Find the lowest number whose tcbRecoveryEventDate is >= cutoff_date
+    for item in sorted_numbers:
+        if item.tcb_recovery_event_date >= cutoff_date:
+            return item.tcb_evaluation_data_number
+
+    # If every entry is older than the cutoff, fail closed and report the
+    # most recent one; in practice at least the newest entry should fall
+    # within the window.
+    highest = sorted_numbers[-1]
+    raise CollateralError(
+        f"All TCB evaluation data numbers are older than {max_age_days} days. "
+        f"Most recent is {highest.tcb_evaluation_data_number} from "
+        f"{highest.tcb_recovery_event_date}."
+    )
+
+
+# =============================================================================
+# TCB Level Matching and Validation
+# =============================================================================
+
+def is_cpu_svn_higher_or_equal(
+    pck_cert_cpu_svn: bytes,
+    sgx_tcb_components: List[TcbComponent],
+) -> bool:
+    """
+    Check if PCK certificate CPU SVN is >= TCB level SGX components.
+
+    Args:
+        pck_cert_cpu_svn: CPU SVN from PCK certificate (16 bytes)
+        sgx_tcb_components: SGX TCB components from TCB level
+
+    Returns:
+        True if all components are >=
+    """
+    if len(pck_cert_cpu_svn) != len(sgx_tcb_components):
+        return False
+
+    for i, component in enumerate(sgx_tcb_components):
+        if pck_cert_cpu_svn[i] < component.svn:
+            return False
+
+    return True
+
+
+def is_tdx_tcb_svn_higher_or_equal(
+    tee_tcb_svn: bytes,
+    tdx_tcb_components: List[TcbComponent],
+) -> bool:
+    """
+    Check if TEE TCB SVN is >= TCB level TDX components.
+
+    Args:
+        tee_tcb_svn: TEE TCB SVN from quote body (16 bytes)
+        tdx_tcb_components: TDX TCB components from TCB level
+
+    Returns:
+        True if all relevant components are >=
+    """
+    if len(tee_tcb_svn) != len(tdx_tcb_components):
+        return False
+
+    # If teeTcbSvn[1] > 0, skip first 2 bytes (module-specific behavior)
+    start = 0
+    if len(tee_tcb_svn) > 1 and tee_tcb_svn[1] > 0:
+        start = 2
+
+    for i in range(start, len(tee_tcb_svn)):
+        if tee_tcb_svn[i] < tdx_tcb_components[i].svn:
+            return False
+
+    return True
+
+
+def get_matching_tcb_level(
+    tcb_levels: List[TcbLevel],
+    tee_tcb_svn: bytes,
+    pck_cert_pce_svn: int,
+    pck_cert_cpu_svn: bytes,
+) -> Optional[TcbLevel]:
+    """
+    Find the matching TCB level for the given SVN values.
+
+    TCB levels are ordered from newest to oldest. Returns the first
+    level where all three checks pass.
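+
+    Example (illustrative): a level whose sgxtcbcomponents SVNs start
+    [2, 1, 0, ...] is matched by a PCK CPU SVN of [2, 1, 0, ...] or higher
+    in every position, but not by [1, 2, 0, ...].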
+ + Args: + tcb_levels: List of TCB levels from TCB Info + tee_tcb_svn: TEE TCB SVN from quote body + pck_cert_pce_svn: PCE SVN from PCK certificate + pck_cert_cpu_svn: CPU SVN from PCK certificate + + Returns: + Matching TcbLevel or None if no match found + """ + for level in tcb_levels: + if (is_cpu_svn_higher_or_equal(pck_cert_cpu_svn, level.tcb.sgx_tcb_components) and + pck_cert_pce_svn >= level.tcb.pce_svn and + is_tdx_tcb_svn_higher_or_equal(tee_tcb_svn, level.tcb.tdx_tcb_components)): + return level + + return None + + +def get_matching_qe_tcb_level( + tcb_levels: List[TcbLevel], + isv_svn: int, +) -> Optional[TcbLevel]: + """ + Find the matching TCB level for the QE ISV SVN. + + TCB levels are ordered from newest to oldest. Returns the first + level where the report's ISV SVN >= the level's ISV SVN. + + Args: + tcb_levels: List of TCB levels from QE Identity + isv_svn: ISV SVN from QE report + + Returns: + Matching TcbLevel or None if no match found + """ + for level in tcb_levels: + # QE TCB levels use isvsvn in the tcb structure + # The level matches if report's ISV SVN >= level's ISV SVN + if level.tcb.isv_svn is not None and isv_svn >= level.tcb.isv_svn: + return level + + return None + + +def validate_tcb_status( + tcb_info: TcbInfo, + tee_tcb_svn: bytes, + pck_extensions: PckExtensions, + strict_mode: bool = False, +) -> TcbLevel: + """ + Validate TCB status against Intel's published levels. + + Performs the following checks: + 1. Cross-validates FMSPC and PCE_ID between PCK cert and TCB Info + 2. Finds matching TCB level for platform SVN values + 3. Checks TCB status (REVOKED and OUT_OF_DATE always rejected) + + Args: + tcb_info: Parsed TCB Info from Intel PCS + tee_tcb_svn: TEE TCB SVN from quote body + pck_extensions: PCK certificate extensions + strict_mode: If True, reject SW_HARDENING_NEEDED and CONFIGURATION_NEEDED + + Returns: + Matching TcbLevel + + Raises: + CollateralError: If TCB status is not acceptable + """ + # Cross-check FMSPC between PCK certificate and TCB Info + pck_fmspc = pck_extensions.fmspc.lower() + tcb_fmspc = tcb_info.fmspc.lower() + if pck_fmspc != tcb_fmspc: + raise CollateralError( + f"FMSPC mismatch: PCK certificate has {pck_fmspc}, " + f"TCB Info has {tcb_fmspc}" + ) + + # Cross-check PCE_ID between PCK certificate and TCB Info + pck_pceid = pck_extensions.pceid.lower() + tcb_pceid = tcb_info.pce_id.lower() + if pck_pceid != tcb_pceid: + raise CollateralError( + f"PCE_ID mismatch: PCK certificate has {pck_pceid}, " + f"TCB Info has {tcb_pceid}" + ) + + matching_level = get_matching_tcb_level( + tcb_levels=tcb_info.tcb_levels, + tee_tcb_svn=tee_tcb_svn, + pck_cert_pce_svn=pck_extensions.tcb.pce_svn, + pck_cert_cpu_svn=pck_extensions.tcb.cpu_svn, + ) + + if matching_level is None: + raise CollateralError( + "No matching TCB level found for the quote's SVN values" + ) + + # Check status - always reject REVOKED and OUT_OF_DATE + if matching_level.tcb_status == TcbStatus.REVOKED: + raise CollateralError("TCB status is REVOKED - platform is not trusted") + + if matching_level.tcb_status == TcbStatus.OUT_OF_DATE: + raise CollateralError("TCB status is OUT_OF_DATE - platform needs update") + + if matching_level.tcb_status == TcbStatus.OUT_OF_DATE_CONFIGURATION_NEEDED: + raise CollateralError( + "TCB status is OUT_OF_DATE_CONFIGURATION_NEEDED - platform needs update" + ) + + # In strict mode, also reject statuses that indicate security advisories + if strict_mode: + if matching_level.tcb_status == TcbStatus.SW_HARDENING_NEEDED: + advisories = ", 
".join(matching_level.advisory_ids) or "none listed" + raise CollateralError( + f"TCB status is SW_HARDENING_NEEDED (strict mode). " + f"Advisories: {advisories}" + ) + + if matching_level.tcb_status == TcbStatus.CONFIGURATION_NEEDED: + advisories = ", ".join(matching_level.advisory_ids) or "none listed" + raise CollateralError( + f"TCB status is CONFIGURATION_NEEDED (strict mode). " + f"Advisories: {advisories}" + ) + + if matching_level.tcb_status == TcbStatus.CONFIGURATION_AND_SW_HARDENING_NEEDED: + advisories = ", ".join(matching_level.advisory_ids) or "none listed" + raise CollateralError( + f"TCB status is CONFIGURATION_AND_SW_HARDENING_NEEDED (strict mode). " + f"Advisories: {advisories}" + ) + + return matching_level + + +def get_tdx_module_identity( + tcb_info: TcbInfo, + tee_tcb_svn: bytes, +) -> TdxModuleIdentity: + """ + Find the matching TDX module identity based on TEE_TCB_SVN. + + The module ID is derived from TEE_TCB_SVN[1] (major version): + - TEE_TCB_SVN[1] = 0x03 -> Module ID = "TDX_03" + - TEE_TCB_SVN[1] = 0x01 -> Module ID = "TDX_01" + + Args: + tcb_info: Parsed TCB Info from Intel PCS + tee_tcb_svn: TEE TCB SVN from quote body (16 bytes) + + Returns: + Matching TdxModuleIdentity + + Raises: + CollateralError: If no module identities in TCB Info, tee_tcb_svn is + malformed, or module version is unknown + """ + if not tcb_info.tdx_module_identities: + raise CollateralError("TCB Info has no TDX module identities") + + if len(tee_tcb_svn) < 2: + raise CollateralError( + f"TEE_TCB_SVN is too short ({len(tee_tcb_svn)} bytes, expected >= 2)" + ) + + # Extract major version and form module ID + # TEE_TCB_SVN[0] = minor SVN, TEE_TCB_SVN[1] = major SVN + major_version = tee_tcb_svn[1] + module_id = f"TDX_{major_version:02d}" + + # Find matching module identity + for module_identity in tcb_info.tdx_module_identities: + if module_identity.id == module_id: + return module_identity + + known_ids = [m.id for m in tcb_info.tdx_module_identities] + raise CollateralError( + f"Unknown TDX module version {module_id}. Known module identities: {known_ids}" + ) + + +def validate_tdx_module_identity( + tcb_info: TcbInfo, + tee_tcb_svn: bytes, + mr_signer_seam: bytes, + seam_attributes: bytes, +) -> None: + """ + Validate TDX module identity against Intel's published identities. + + This validates: + - MR_SIGNER_SEAM matches the expected module signer + - SEAM_ATTRIBUTES match under the attribute mask + - Module-specific TCB level matches the minor version (TEE_TCB_SVN[0]) + + Args: + tcb_info: Parsed TCB Info from Intel PCS + tee_tcb_svn: TEE TCB SVN from quote body (16 bytes) + mr_signer_seam: MR_SIGNER_SEAM from quote body (48 bytes) + seam_attributes: SEAM_ATTRIBUTES from quote body (8 bytes) + + Raises: + CollateralError: If module identity validation fails + """ + module_identity = get_tdx_module_identity(tcb_info, tee_tcb_svn) + + # Verify MR_SIGNER_SEAM matches + if mr_signer_seam != module_identity.mrsigner: + raise CollateralError( + f"TDX module MR_SIGNER_SEAM does not match expected value. 
" + f"Got {mr_signer_seam.hex()}, expected {module_identity.mrsigner.hex()}" + ) + + # Verify SEAM_ATTRIBUTES match under mask + # Pad seam_attributes to match mask length if needed + mask = module_identity.attributes_mask + attrs = seam_attributes + + # Handle length mismatch by padding with zeros + if len(attrs) < len(mask): + attrs = attrs + b'\x00' * (len(mask) - len(attrs)) + if len(mask) < len(attrs): + mask = mask + b'\x00' * (len(attrs) - len(mask)) + + report_attrs_masked = bytes(a & b for a, b in zip(attrs, mask)) + expected_attrs_masked = bytes( + a & b for a, b in zip(module_identity.attributes, mask) + ) + if report_attrs_masked != expected_attrs_masked: + raise CollateralError( + f"TDX module SEAM_ATTRIBUTES do not match expected value under mask. " + f"Got {seam_attributes.hex()}, expected {module_identity.attributes.hex()} " + f"(mask {module_identity.attributes_mask.hex()})" + ) + + # Find matching module-specific TCB level + # TEE_TCB_SVN[0] = minor SVN, TEE_TCB_SVN[1] = major SVN + # The minor version is used for module TCB matching + minor_version = tee_tcb_svn[0] + + for level in module_identity.tcb_levels: + # Module TCB levels should have isv_svn (minor version) set + if level.tcb.isv_svn is not None and minor_version >= level.tcb.isv_svn: + # Check status - reject insecure statuses consistently with + # validate_tcb_status and validate_qe_identity + if level.tcb_status == TcbStatus.REVOKED: + raise CollateralError( + f"TDX module TCB status is REVOKED for version " + f"{tee_tcb_svn[1]}.{minor_version}" + ) + if level.tcb_status == TcbStatus.OUT_OF_DATE: + raise CollateralError( + f"TDX module TCB status is OUT_OF_DATE for version " + f"{tee_tcb_svn[1]}.{minor_version}" + ) + if level.tcb_status == TcbStatus.OUT_OF_DATE_CONFIGURATION_NEEDED: + raise CollateralError( + f"TDX module TCB status is OUT_OF_DATE_CONFIGURATION_NEEDED " + f"for version {tee_tcb_svn[1]}.{minor_version}" + ) + return + + # No matching TCB level found - this is an error (matches Go behavior) + raise CollateralError( + f"Could not find a TDX Module Identity TCB Level matching " + f"the TDX Module's ISVSVN ({minor_version})" + ) + + +def validate_qe_identity( + qe_identity: EnclaveIdentity, + qe_report_isv_svn: int, + qe_report_mrsigner: bytes, + qe_report_miscselect: bytes, + qe_report_attributes: bytes, + qe_report_isvprodid: int, +) -> TcbLevel: + """ + Validate QE identity against Intel's published identity. + + This performs comprehensive validation of the Quoting Enclave: + - MRSIGNER must match exactly + - MISCSELECT must match under mask + - Attributes must match under mask + - ISV ProdID must match exactly + - ISV SVN must meet minimum threshold for a TCB level + + Args: + qe_identity: Parsed QE Identity from Intel PCS + qe_report_isv_svn: ISV SVN from QE report + qe_report_mrsigner: MRSIGNER from QE report (32 bytes) + qe_report_miscselect: MISCSELECT from QE report (4 bytes) + qe_report_attributes: Attributes from QE report (16 bytes) + qe_report_isvprodid: ISV ProdID from QE report + + Returns: + Matching TcbLevel + + Raises: + CollateralError: If QE identity validation fails + """ + # Verify MRSIGNER matches exactly + if qe_report_mrsigner != qe_identity.mrsigner: + raise CollateralError( + f"QE report MRSIGNER does not match expected value. 
" + f"Got {qe_report_mrsigner.hex()}, expected {qe_identity.mrsigner.hex()}" + ) + + # Verify MISCSELECT matches under mask + # (qe_report_miscselect & mask) == (qe_identity.miscselect & mask) + report_miscselect_masked = bytes( + a & b for a, b in zip(qe_report_miscselect, qe_identity.miscselect_mask) + ) + expected_miscselect_masked = bytes( + a & b for a, b in zip(qe_identity.miscselect, qe_identity.miscselect_mask) + ) + if report_miscselect_masked != expected_miscselect_masked: + raise CollateralError( + f"QE report MISCSELECT does not match expected value under mask. " + f"Got {qe_report_miscselect.hex()}, expected {qe_identity.miscselect.hex()} " + f"(mask {qe_identity.miscselect_mask.hex()})" + ) + + # Verify Attributes match under mask + # (qe_report_attributes & mask) == (qe_identity.attributes & mask) + report_attributes_masked = bytes( + a & b for a, b in zip(qe_report_attributes, qe_identity.attributes_mask) + ) + expected_attributes_masked = bytes( + a & b for a, b in zip(qe_identity.attributes, qe_identity.attributes_mask) + ) + if report_attributes_masked != expected_attributes_masked: + raise CollateralError( + f"QE report Attributes do not match expected value under mask. " + f"Got {qe_report_attributes.hex()}, expected {qe_identity.attributes.hex()} " + f"(mask {qe_identity.attributes_mask.hex()})" + ) + + # Verify ISV ProdID matches exactly + if qe_report_isvprodid != qe_identity.isv_prod_id: + raise CollateralError( + f"QE report ISV ProdID does not match expected value. " + f"Got {qe_report_isvprodid}, expected {qe_identity.isv_prod_id}" + ) + + # Find matching TCB level + matching_level = get_matching_qe_tcb_level( + tcb_levels=qe_identity.tcb_levels, + isv_svn=qe_report_isv_svn, + ) + + if matching_level is None: + raise CollateralError( + f"No matching QE TCB level found for ISV SVN {qe_report_isv_svn}" + ) + + # Check status + if matching_level.tcb_status == TcbStatus.REVOKED: + raise CollateralError("QE TCB status is REVOKED") + + if matching_level.tcb_status == TcbStatus.OUT_OF_DATE: + raise CollateralError("QE TCB status is OUT_OF_DATE") + + if matching_level.tcb_status == TcbStatus.OUT_OF_DATE_CONFIGURATION_NEEDED: + raise CollateralError("QE TCB status is OUT_OF_DATE_CONFIGURATION_NEEDED") + + return matching_level + + +def check_collateral_freshness( + collateral: TdxCollateral, + min_tcb_evaluation_data_number: int, +) -> None: + """ + Check that collateral is not expired and meets freshness requirements. + + This performs the following checks: + 1. TCB Info has not expired (now < next_update) + 2. QE Identity has not expired (now < next_update) + 3. Both TCB Info and QE Identity must have tcbEvaluationDataNumber >= threshold + + The tcbEvaluationDataNumber is a monotonically increasing number that + Intel updates when new TCB recovery (TCB-R) events occur. The minimum + threshold ensures we don't accept collateral issued before critical + security updates. + + See: https://www.intel.com/content/www/us/en/developer/topic-technology/software-security-guidance/trusted-computing-base-recovery-attestation.html + + Args: + collateral: TDX collateral to check + min_tcb_evaluation_data_number: Minimum tcbEvaluationDataNumber threshold. + Collateral with a lower number is rejected. 
+ + Raises: + CollateralError: If collateral is expired or too old + """ + now = datetime.now(timezone.utc) + + tcb_next_update = collateral.tcb_info.tcb_info.next_update + if now > tcb_next_update: + raise CollateralError( + f"TCB Info has expired (next update was {tcb_next_update})" + ) + + qe_next_update = collateral.qe_identity.enclave_identity.next_update + if now > qe_next_update: + raise CollateralError( + f"QE Identity has expired (next update was {qe_next_update})" + ) + + # Check tcbEvaluationDataNumber threshold + tcb_eval_num = collateral.tcb_info.tcb_info.tcb_evaluation_data_number + if tcb_eval_num < min_tcb_evaluation_data_number: + raise CollateralError( + f"TCB Info tcbEvaluationDataNumber ({tcb_eval_num}) is below " + f"minimum required ({min_tcb_evaluation_data_number}). " + f"Collateral may be outdated." + ) + + qe_eval_num = collateral.qe_identity.enclave_identity.tcb_evaluation_data_number + if qe_eval_num < min_tcb_evaluation_data_number: + raise CollateralError( + f"QE Identity tcbEvaluationDataNumber ({qe_eval_num}) is below " + f"minimum required ({min_tcb_evaluation_data_number}). " + f"Collateral may be outdated." + ) + + +def validate_certificate_revocation( + collateral: TdxCollateral, + pck_cert: x509.Certificate, + intermediate_cert: Optional[x509.Certificate] = None, +) -> None: + """ + Validate that certificates in the attestation chain have not been revoked. + + Per Intel DCAP spec Section 2.3, checks: + 1. PCK (leaf) certificate against the PCK CRL from Intel PCS + 2. Intermediate CA certificate against the Intel SGX Root CA CRL + 3. TCB Info signing certificate against the Intel SGX Root CA CRL + 4. QE Identity signing certificate against the Intel SGX Root CA CRL + + Args: + collateral: TDX collateral containing CRLs and issuer chains + pck_cert: PCK certificate to check + intermediate_cert: Intermediate CA certificate to check (optional) + + Raises: + CollateralError: If any certificate is revoked or CRL check fails + + TODO: Use cryptography.x509.verification integrated CRL checking when available. + As of cryptography 46.0.3, PolicyBuilder doesn't support CRL stores. + Track: https://github.com/pyca/cryptography/issues + """ + if collateral.pck_crl is None: + raise CollateralError( + "Cannot check certificate revocation: PCK CRL not available in collateral" + ) + + # --- Check PCK (leaf) certificate against PCK CRL --- + pck_crl = collateral.pck_crl.crl + + # Check PCK CRL freshness + now = datetime.now(timezone.utc) + pck_crl_next_update = pck_crl.next_update_utc + if pck_crl_next_update is not None and now > pck_crl_next_update: + raise CollateralError( + f"PCK CRL has expired (next update was {pck_crl_next_update})" + ) + + # Check if PCK certificate is revoked + pck_serial = pck_cert.serial_number + revoked_pck = pck_crl.get_revoked_certificate_by_serial_number(pck_serial) + + if revoked_pck is not None: + revocation_date = revoked_pck.revocation_date_utc + raise CollateralError( + f"PCK certificate has been revoked. " + f"Serial: {pck_serial:x}, Revocation date: {revocation_date}" + ) + + # --- Check certificates against Root CA CRL --- + # The Root CRL lists revoked intermediate CAs (Platform CA, Processor CA, + # TCB Signing CA, etc.) issued by Intel SGX Root CA. 
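+    # (The PCK CRL above does not cover these: it is issued by the PCK CA
+    # itself and lists only revoked PCK leaf certificates.)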
+    if collateral.root_crl is not None:
+        root_crl = collateral.root_crl.crl
+
+        # Check Root CRL freshness (once, shared across all checks below)
+        root_crl_next_update = root_crl.next_update_utc
+        if root_crl_next_update is not None and now > root_crl_next_update:
+            raise CollateralError(
+                f"Intel SGX Root CA CRL has expired (next update was {root_crl_next_update})"
+            )
+
+        # Check intermediate CA certificate (from PCK cert chain)
+        if intermediate_cert is not None:
+            _check_cert_against_crl(
+                root_crl, intermediate_cert, "Intermediate CA"
+            )
+
+        # Check TCB Info signing certificate (per DCAP spec Section 2.3:
+        # "Check if verification collaterals are on the CRL")
+        if collateral.tcb_info_issuer_chain:
+            _check_cert_against_crl(
+                root_crl, collateral.tcb_info_issuer_chain[0], "TCB Info signing"
+            )
+
+        # Check QE Identity signing certificate
+        if collateral.qe_identity_issuer_chain:
+            _check_cert_against_crl(
+                root_crl, collateral.qe_identity_issuer_chain[0], "QE Identity signing"
+            )
+
+    elif (intermediate_cert is not None
+          or collateral.tcb_info_issuer_chain
+          or collateral.qe_identity_issuer_chain):
+        raise CollateralError(
+            "Cannot check certificate revocation: Root CRL not available in collateral"
+        )
+
+
+def _check_cert_against_crl(
+    crl: x509.CertificateRevocationList,
+    cert: x509.Certificate,
+    cert_name: str,
+) -> None:
+    """
+    Check if a certificate has been revoked according to a CRL.
+
+    Args:
+        crl: Certificate Revocation List to check against
+        cert: Certificate to check
+        cert_name: Human-readable name for error messages
+
+    Raises:
+        CollateralError: If the certificate has been revoked
+    """
+    serial = cert.serial_number
+    revoked = crl.get_revoked_certificate_by_serial_number(serial)
+    if revoked is not None:
+        revocation_date = revoked.revocation_date_utc
+        raise CollateralError(
+            f"{cert_name} certificate has been revoked. "
+            f"Serial: {serial:x}, Revocation date: {revocation_date}"
+        )
+
+
+# =============================================================================
+# Collateral Validation Orchestration
+# =============================================================================
+
+@dataclass
+class CollateralValidationResult:
+    """
+    Result of collateral validation.
+
+    Attributes:
+        collateral: The fetched TDX collateral
+        tcb_level: The matching TCB level for the platform
+        pck_extensions: Extracted PCK certificate extensions
+    """
+    collateral: TdxCollateral
+    tcb_level: TcbLevel
+    pck_extensions: PckExtensions
+
+
+def validate_collateral(
+    quote: "QuoteV4",
+    pck_chain: "PCKCertificateChain",
+    min_tcb_evaluation_data_number: int = 18,
+) -> CollateralValidationResult:
+    """
+    Validate all collateral for a TDX quote.
+
+    This function orchestrates the complete collateral validation flow:
+    1. Extract PCK extensions (FMSPC, TCB components)
+    2. Fetch collateral from Intel PCS (TCB Info, QE Identity)
+    3. Check collateral freshness
+    4. Validate certificate revocation
+    5. Validate TCB status
+    6. Validate TDX module identity
+    7. Validate QE identity
+
+    Args:
+        quote: Parsed TDX QuoteV4
+        pck_chain: Extracted PCK certificate chain
+        min_tcb_evaluation_data_number: Minimum required tcbEvaluationDataNumber
+
+    Returns:
+        CollateralValidationResult with validated collateral and TCB level
+
+    Raises:
+        CollateralError: If any collateral validation step fails
+    """
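+
+    # Illustrative usage: a verifier typically runs
+    #     result = validate_collateral(quote, pck_chain)
+    # and then applies its own policy to result.tcb_level.tcb_status and
+    # result.tcb_level.advisory_ids before trusting the TD measurements.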
+
+    # Step 1: Extract PCK extensions
+    try:
+        pck_extensions = extract_pck_extensions(pck_chain.pck_cert)
+    except PckExtensionError as e:
+        raise CollateralError(f"Failed to extract PCK extensions: {e}") from e
+
+    # Step 2: Fetch collateral from Intel PCS
+    collateral = fetch_collateral(pck_extensions, pck_chain.pck_cert)
+
+    # Step 3: Check collateral freshness
+    check_collateral_freshness(collateral, min_tcb_evaluation_data_number)
+
+    # Step 4: Validate certificate revocation
+    validate_certificate_revocation(
+        collateral, pck_chain.pck_cert, pck_chain.intermediate_cert
+    )
+
+    # Step 5: Validate TCB status
+    tcb_level = validate_tcb_status(
+        collateral.tcb_info.tcb_info,
+        quote.td_quote_body.tee_tcb_svn,
+        pck_extensions,
+    )
+
+    # Step 6: Validate TDX module identity
+    validate_tdx_module_identity(
+        collateral.tcb_info.tcb_info,
+        quote.td_quote_body.tee_tcb_svn,
+        quote.td_quote_body.mr_signer_seam,
+        quote.td_quote_body.seam_attributes,
+    )
+
+    # Step 7: Validate QE identity
+    qe_report = quote.signed_data.certification_data.qe_report_data
+    if qe_report is None:
+        raise CollateralError("Quote missing QE report certification data")
+    qe_parsed = qe_report.qe_report_parsed
+    miscselect_bytes = qe_parsed.misc_select.to_bytes(4, byteorder='little')
+    validate_qe_identity(
+        collateral.qe_identity.enclave_identity,
+        qe_parsed.isv_svn,
+        qe_parsed.mr_signer,
+        miscselect_bytes,
+        qe_parsed.attributes,
+        qe_parsed.isv_prod_id,
+    )
+
+    return CollateralValidationResult(
+        collateral=collateral,
+        tcb_level=tcb_level,
+        pck_extensions=pck_extensions,
+    )
diff --git a/tests/test_collateral_tdx.py b/tests/test_collateral_tdx.py
new file mode 100644
index 0000000..0735e70
--- /dev/null
+++ b/tests/test_collateral_tdx.py
@@ -0,0 +1,1892 @@
+"""
+Unit tests for TDX collateral fetching and TCB validation.
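+
+The sample PCS payloads below are based on real Intel PCS responses;
+their "signature" fields are placeholders, so signature verification is
+mocked wherever these fixtures are exercised.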
+""" + +import json +import pytest +from datetime import datetime, timezone, timedelta +from unittest.mock import patch, MagicMock + +from tinfoil.attestation.collateral_tdx import ( + TcbStatus, + TcbComponent, + Tcb, + TcbLevel, + TcbInfo, + TdxTcbInfo, + TdxModuleIdentity, + EnclaveIdentity, + QeIdentity, + TdxCollateral, + CollateralError, + parse_tcb_info_response, + parse_qe_identity_response, + is_cpu_svn_higher_or_equal, + is_tdx_tcb_svn_higher_or_equal, + get_matching_tcb_level, + get_matching_qe_tcb_level, + get_tdx_module_identity, + validate_tcb_status, + validate_tdx_module_identity, + validate_qe_identity, + check_collateral_freshness, + fetch_tcb_info, + fetch_qe_identity, + _parse_datetime, + _parse_issuer_chain_header, + _verify_collateral_signature, + _parse_hex_bytes, + _get_tcb_info_cache_path, + _get_qe_identity_cache_path, + _is_tcb_info_fresh, + _is_qe_identity_fresh, + _read_cache, + _write_cache, +) +from tinfoil.attestation.pck_extensions import PckExtensions, PckCertTCB + + +# ============================================================================= +# Sample Data - Based on Real Intel PCS Responses +# ============================================================================= + +SAMPLE_TCB_INFO_JSON = """ +{ + "tcbInfo": { + "id": "TDX", + "version": 3, + "issueDate": "2025-12-17T06:24:56Z", + "nextUpdate": "2026-01-16T06:24:56Z", + "fmspc": "90c06f000000", + "pceId": "0000", + "tcbType": 0, + "tcbEvaluationDataNumber": 18, + "tdxModule": { + "mrsigner": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "attributes": "0000000000000000", + "attributesMask": "FFFFFFFFFFFFFFFF" + }, + "tdxModuleIdentities": [ + { + "id": "TDX_03", + "mrsigner": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "attributes": "0000000000000000", + "attributesMask": "FFFFFFFFFFFFFFFF", + "tcbLevels": [ + { + "tcb": { "isvsvn": 3 }, + "tcbDate": "2024-11-13T00:00:00Z", + "tcbStatus": "UpToDate" + } + ] + } + ], + "tcbLevels": [ + { + "tcb": { + "sgxtcbcomponents": [ + {"svn": 3}, {"svn": 3}, {"svn": 2}, {"svn": 2}, + {"svn": 2}, {"svn": 1}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0} + ], + "pcesvn": 13, + "tdxtcbcomponents": [ + {"svn": 5, "category": "TDX Module", "type": "TDX Module"}, + {"svn": 0}, {"svn": 3}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0} + ] + }, + "tcbDate": "2024-11-13T00:00:00Z", + "tcbStatus": "UpToDate" + }, + { + "tcb": { + "sgxtcbcomponents": [ + {"svn": 2}, {"svn": 2}, {"svn": 2}, {"svn": 2}, + {"svn": 2}, {"svn": 1}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0} + ], + "pcesvn": 11, + "tdxtcbcomponents": [ + {"svn": 4}, {"svn": 0}, {"svn": 2}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0} + ] + }, + "tcbDate": "2024-03-13T00:00:00Z", + "tcbStatus": "OutOfDate" + } + ] + }, + "signature": "abcd1234" +} +""" + +SAMPLE_QE_IDENTITY_JSON = """ +{ + "enclaveIdentity": { + "id": "TD_QE", + "version": 2, + "issueDate": "2025-12-17T18:48:11Z", + "nextUpdate": "2026-01-16T18:48:11Z", + "tcbEvaluationDataNumber": 18, + "miscselect": "00000000", + "miscselectMask": "FFFFFFFF", + 
"attributes": "11000000000000000000000000000000", + "attributesMask": "FBFFFFFFFFFFFFFF0000000000000000", + "mrsigner": "DC9E2A7C6F948F17474E34A7FC43ED030F7C1563F1BABDDF6340C82E0E54A8C5", + "isvprodid": 2, + "tcbLevels": [ + { + "tcb": { "isvsvn": 4 }, + "tcbDate": "2024-11-13T00:00:00Z", + "tcbStatus": "UpToDate" + }, + { + "tcb": { "isvsvn": 3 }, + "tcbDate": "2024-03-13T00:00:00Z", + "tcbStatus": "OutOfDate" + } + ] + }, + "signature": "0665a932" +} +""" + +# Byte versions for cache testing +SAMPLE_TCB_INFO_RESPONSE = SAMPLE_TCB_INFO_JSON.encode() +SAMPLE_QE_IDENTITY_RESPONSE = SAMPLE_QE_IDENTITY_JSON.encode() + +# Stale TCB Info (next_update in the past) +SAMPLE_STALE_TCB_INFO_JSON = """ +{ + "tcbInfo": { + "id": "TDX", + "version": 3, + "issueDate": "2024-01-17T06:24:56Z", + "nextUpdate": "2024-02-16T06:24:56Z", + "fmspc": "90c06f000000", + "pceId": "0000", + "tcbType": 0, + "tcbEvaluationDataNumber": 18, + "tdxModuleIdentities": [], + "tcbLevels": [] + }, + "signature": "stale1234" +} +""" +SAMPLE_STALE_TCB_INFO_RESPONSE = SAMPLE_STALE_TCB_INFO_JSON.encode() + + +# ============================================================================= +# Helper Functions +# ============================================================================= + +def create_sample_pck_extensions() -> PckExtensions: + """Create sample PCK extensions for testing.""" + return PckExtensions( + ppid="00" * 16, + tcb=PckCertTCB( + pce_svn=13, + cpu_svn=bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + tcb_components=[3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ), + pceid="0000", + fmspc="90c06f000000", + ) + + +# ============================================================================= +# Parsing Tests +# ============================================================================= + +class TestParseTcbInfoResponse: + """Test TCB Info parsing.""" + + def test_parse_valid_response(self): + """Test parsing a valid TCB Info response.""" + result = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + assert result.tcb_info.id == "TDX" + assert result.tcb_info.version == 3 + assert result.tcb_info.fmspc == "90c06f000000" + assert len(result.tcb_info.tcb_levels) == 2 + assert result.signature == "abcd1234" + + def test_parse_tcb_levels(self): + """Test that TCB levels are parsed correctly.""" + result = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + first_level = result.tcb_info.tcb_levels[0] + assert first_level.tcb_status == TcbStatus.UP_TO_DATE + assert first_level.tcb.pce_svn == 13 + assert len(first_level.tcb.sgx_tcb_components) == 16 + assert first_level.tcb.sgx_tcb_components[0].svn == 3 + + def test_parse_tdx_module_identities(self): + """Test that TDX module identities are parsed.""" + result = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + assert len(result.tcb_info.tdx_module_identities) == 1 + identity = result.tcb_info.tdx_module_identities[0] + assert identity.id == "TDX_03" + + def test_reject_wrong_id(self): + """Test that non-TDX ID is rejected.""" + bad_json = SAMPLE_TCB_INFO_JSON.replace('"id": "TDX"', '"id": "SGX"') + with pytest.raises(CollateralError, match="must be 'TDX'"): + parse_tcb_info_response(bad_json.encode()) + + def test_reject_wrong_version(self): + """Test that wrong version is rejected.""" + bad_json = SAMPLE_TCB_INFO_JSON.replace('"version": 3', '"version": 2') + with pytest.raises(CollateralError, match="must be 3"): + parse_tcb_info_response(bad_json.encode()) + + +class TestParseQeIdentityResponse: + """Test QE Identity 
parsing.""" + + def test_parse_valid_response(self): + """Test parsing a valid QE Identity response.""" + result = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + assert result.enclave_identity.id == "TD_QE" + assert result.enclave_identity.version == 2 + assert result.enclave_identity.isv_prod_id == 2 + assert len(result.enclave_identity.tcb_levels) == 2 + + def test_parse_mrsigner(self): + """Test MRSIGNER is parsed correctly.""" + result = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + expected = bytes.fromhex( + "DC9E2A7C6F948F17474E34A7FC43ED030F7C1563F1BABDDF6340C82E0E54A8C5" + ) + assert result.enclave_identity.mrsigner == expected + + def test_reject_wrong_id(self): + """Test that non-TD_QE ID is rejected.""" + bad_json = SAMPLE_QE_IDENTITY_JSON.replace('"id": "TD_QE"', '"id": "QE"') + with pytest.raises(CollateralError, match="must be 'TD_QE'"): + parse_qe_identity_response(bad_json.encode()) + + +# ============================================================================= +# TCB Comparison Tests +# ============================================================================= + +class TestIsCpuSvnHigherOrEqual: + """Test CPU SVN comparison.""" + + def test_equal_values(self): + """Test equal SVN values pass.""" + cpu_svn = bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + components = [TcbComponent(svn=3), TcbComponent(svn=3), TcbComponent(svn=2), + TcbComponent(svn=2), TcbComponent(svn=2), TcbComponent(svn=1), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_cpu_svn_higher_or_equal(cpu_svn, components) is True + + def test_higher_values(self): + """Test higher SVN values pass.""" + cpu_svn = bytes([4, 4, 3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + components = [TcbComponent(svn=3), TcbComponent(svn=3), TcbComponent(svn=2), + TcbComponent(svn=2), TcbComponent(svn=2), TcbComponent(svn=1), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_cpu_svn_higher_or_equal(cpu_svn, components) is True + + def test_lower_first_component(self): + """Test lower first component fails.""" + cpu_svn = bytes([2, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + components = [TcbComponent(svn=3), TcbComponent(svn=3), TcbComponent(svn=2), + TcbComponent(svn=2), TcbComponent(svn=2), TcbComponent(svn=1), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_cpu_svn_higher_or_equal(cpu_svn, components) is False + + def test_wrong_length(self): + """Test wrong length fails.""" + cpu_svn = bytes([3, 3, 2]) # Only 3 bytes + components = [TcbComponent(svn=3)] * 16 + + assert is_cpu_svn_higher_or_equal(cpu_svn, components) is False + + +class TestIsTdxTcbSvnHigherOrEqual: + """Test TDX TCB SVN comparison.""" + + def test_equal_values(self): + """Test equal SVN values pass.""" + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + components = [TcbComponent(svn=5), TcbComponent(svn=0), TcbComponent(svn=3), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), 
TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_tdx_tcb_svn_higher_or_equal(tee_tcb_svn, components) is True + + def test_skip_first_two_when_module_version_set(self): + """Test that first 2 bytes are skipped when tee_tcb_svn[1] > 0.""" + # tee_tcb_svn[1] = 1, so first 2 bytes should be skipped + tee_tcb_svn = bytes([0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + # Components have higher values in first 2 positions, but should be skipped + components = [TcbComponent(svn=99), TcbComponent(svn=99), TcbComponent(svn=3), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_tdx_tcb_svn_higher_or_equal(tee_tcb_svn, components) is True + + def test_lower_value_fails(self): + """Test lower SVN value fails.""" + tee_tcb_svn = bytes([5, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + components = [TcbComponent(svn=5), TcbComponent(svn=0), TcbComponent(svn=3), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_tdx_tcb_svn_higher_or_equal(tee_tcb_svn, components) is False + + +class TestGetMatchingTcbLevel: + """Test TCB level matching.""" + + def test_find_matching_level(self): + """Test finding a matching TCB level.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + pce_svn = 13 + cpu_svn = bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result = get_matching_tcb_level( + tcb_info.tcb_info.tcb_levels, + tee_tcb_svn, + pce_svn, + cpu_svn, + ) + + assert result is not None + assert result.tcb_status == TcbStatus.UP_TO_DATE + + def test_no_matching_level(self): + """Test no matching level found.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + # Very low SVN values - won't match any level + tee_tcb_svn = bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + pce_svn = 1 + cpu_svn = bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result = get_matching_tcb_level( + tcb_info.tcb_info.tcb_levels, + tee_tcb_svn, + pce_svn, + cpu_svn, + ) + + assert result is None + + +class TestGetMatchingQeTcbLevel: + """Test QE TCB level matching with isvsvn.""" + + def test_find_matching_level(self): + """Test finding a matching QE TCB level by isvsvn.""" + tcb_levels = [ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=8, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=6, + ), + tcb_date="2023-06-01T00:00:00Z", + tcb_status=TcbStatus.SW_HARDENING_NEEDED, + advisory_ids=["INTEL-SA-00001"], + ), + ] + + # ISV SVN 8 should match first level (equal) + result = get_matching_qe_tcb_level(tcb_levels, 8) + assert result is not None + assert result.tcb_status == 
TcbStatus.UP_TO_DATE + + # ISV SVN 10 should also match first level (>= 8) + result = get_matching_qe_tcb_level(tcb_levels, 10) + assert result is not None + assert result.tcb_status == TcbStatus.UP_TO_DATE + + # ISV SVN 7 should match second level (>= 6 but < 8) + result = get_matching_qe_tcb_level(tcb_levels, 7) + assert result is not None + assert result.tcb_status == TcbStatus.SW_HARDENING_NEEDED + + def test_no_matching_level(self): + """Test no matching QE TCB level when isvsvn too low.""" + tcb_levels = [ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=8, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ] + + # ISV SVN 5 is less than 8, so no match + result = get_matching_qe_tcb_level(tcb_levels, 5) + assert result is None + + def test_empty_levels(self): + """Test empty TCB levels list.""" + result = get_matching_qe_tcb_level([], 8) + assert result is None + + def test_level_without_isv_svn(self): + """Test TCB level without isv_svn field is skipped.""" + tcb_levels = [ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=None, # No isv_svn set + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ] + + result = get_matching_qe_tcb_level(tcb_levels, 8) + assert result is None + + +# ============================================================================= +# Validation Tests +# ============================================================================= + +class TestValidateTcbStatus: + """Test TCB status validation.""" + + def test_up_to_date_passes(self): + """Test UpToDate status passes.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + pck_ext = create_sample_pck_extensions() + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result = validate_tcb_status( + tcb_info.tcb_info, + tee_tcb_svn, + pck_ext, + ) + + assert result.tcb_status == TcbStatus.UP_TO_DATE + + def test_no_matching_level_fails(self): + """Test that no matching level raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + # Create extensions with very low SVN values + pck_ext = PckExtensions( + ppid="00" * 16, + tcb=PckCertTCB( + pce_svn=1, + cpu_svn=bytes(16), + tcb_components=[0] * 16, + ), + pceid="0000", + fmspc="90c06f000000", + ) + tee_tcb_svn = bytes(16) + + with pytest.raises(CollateralError, match="No matching TCB level"): + validate_tcb_status(tcb_info.tcb_info, tee_tcb_svn, pck_ext) + + def test_fmspc_mismatch_fails(self): + """Test that FMSPC mismatch raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + # Create extensions with different FMSPC + pck_ext = PckExtensions( + ppid="00" * 16, + tcb=PckCertTCB( + pce_svn=13, + cpu_svn=bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + tcb_components=[3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ), + pceid="0000", + fmspc="aabbcc000000", # Different FMSPC + ) + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + with pytest.raises(CollateralError, match="FMSPC mismatch"): + validate_tcb_status(tcb_info.tcb_info, tee_tcb_svn, pck_ext) + + def test_pceid_mismatch_fails(self): + """Test that PCE_ID mismatch raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + # Create extensions with 
different PCE_ID + pck_ext = PckExtensions( + ppid="00" * 16, + tcb=PckCertTCB( + pce_svn=13, + cpu_svn=bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + tcb_components=[3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ), + pceid="1234", # Different PCE_ID + fmspc="90c06f000000", + ) + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + with pytest.raises(CollateralError, match="PCE_ID mismatch"): + validate_tcb_status(tcb_info.tcb_info, tee_tcb_svn, pck_ext) + + +class TestValidateQeIdentity: + """Test QE identity validation.""" + + def _create_qe_identity(self) -> EnclaveIdentity: + """Create a sample QE Identity for testing.""" + return EnclaveIdentity( + id="TD_QE", + version=2, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + tcb_evaluation_data_number=17, + miscselect=b"\x00\x00\x00\x00", + miscselect_mask=b"\xff\xff\xff\xff", + attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + attributes_mask=b"\xfb\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00", + mrsigner=b"\xdc" * 32, + isv_prod_id=1, + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=8, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ) + + def test_valid_qe_identity(self): + """Test validation passes with matching QE identity.""" + qe_identity = self._create_qe_identity() + result = validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=1, + ) + assert result.tcb_status == TcbStatus.UP_TO_DATE + + def test_mrsigner_mismatch(self): + """Test validation fails with MRSIGNER mismatch.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="MRSIGNER does not match"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xaa" * 32, # Wrong MRSIGNER + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=1, + ) + + def test_miscselect_mismatch(self): + """Test validation fails with MISCSELECT mismatch under mask.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="MISCSELECT does not match"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x01\x00\x00\x00", # Wrong MISCSELECT + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=1, + ) + + def test_attributes_mismatch(self): + """Test validation fails with Attributes mismatch under mask.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="Attributes do not match"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", # Wrong + qe_report_isvprodid=1, + ) + + def test_isvprodid_mismatch(self): + """Test validation fails with ISV ProdID mismatch.""" + qe_identity = self._create_qe_identity() + with 
pytest.raises(CollateralError, match="ISV ProdID does not match"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=2, # Wrong ISV ProdID + ) + + def test_isv_svn_too_low(self): + """Test validation fails when ISV SVN is too low.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="No matching QE TCB level"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=5, # Below required 8 + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=1, + ) + + +class TestTdxModuleIdentity: + """Test TDX module identity validation.""" + + def _create_tcb_info_with_module_identities(self) -> TcbInfo: + """Create TCB Info with module identities for testing.""" + return TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=[ + TdxModuleIdentity( + id="TDX_03", + mrsigner=b"\xaa" * 48, + attributes=b"\x00\x00\x00\x00\x00\x00\x00\x00", + attributes_mask=b"\xff\xff\xff\xff\xff\xff\xff\xff", + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=3, # Minor version 3 + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ), + ], + tcb_levels=[], + ) + + def test_get_module_identity(self): + """Test finding module identity by TEE_TCB_SVN.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # Major version 3 should match TDX_03 + tee_tcb_svn = bytes([5, 3] + [0] * 14) # minor=5, major=3 + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_03" + + def test_get_module_identity_not_found(self): + """Test unknown module version raises error.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # Major version 5 should not match any (only TDX_03 exists) + tee_tcb_svn = bytes([0, 5] + [0] * 14) # minor=0, major=5 + with pytest.raises(CollateralError, match="Unknown TDX module version TDX_05"): + get_tdx_module_identity(tcb_info, tee_tcb_svn) + + def test_validate_module_identity_success(self): + """Test successful module identity validation.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # minor=5 (used for TCB level matching), major=3 (used to find TDX_03) + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = b"\xaa" * 48 + seam_attributes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + # Should not raise - validation succeeds + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_mrsigner_mismatch(self): + """Test module identity validation fails with MR_SIGNER_SEAM mismatch.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + tee_tcb_svn = bytes([5, 3] + [0] * 14) # 
minor=5, major=3 + mr_signer_seam = b"\xbb" * 48 # Wrong MR_SIGNER_SEAM + seam_attributes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + with pytest.raises(CollateralError, match="MR_SIGNER_SEAM does not match"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_unknown_version_raises(self): + """Test validation raises when module version is unknown.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # Major version 5 doesn't exist in module identities (only TDX_03) + tee_tcb_svn = bytes([0, 5] + [0] * 14) # minor=0, major=5 + mr_signer_seam = b"\xaa" * 48 + seam_attributes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + with pytest.raises(CollateralError, match="Unknown TDX module version TDX_05"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_no_matching_tcb_level(self): + """Test validation raises when minor version is too low for any TCB level.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # Module identity TDX_03 has TCB level with isv_svn=3 + # minor=2 < 3, so no TCB level should match + tee_tcb_svn = bytes([2, 3] + [0] * 14) # minor=2, major=3 + mr_signer_seam = b"\xaa" * 48 + seam_attributes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + with pytest.raises(CollateralError, match="Could not find a TDX Module Identity TCB Level"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + +class TestCheckCollateralFreshness: + """Test collateral freshness checking.""" + + def test_fresh_collateral_passes(self): + """Test fresh collateral passes.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + # Mock datetime to a time before sample data's nextUpdate (2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + # Should not raise (sample has tcbEvaluationDataNumber=18) + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=18) + + def test_expired_tcb_info_fails(self): + """Test expired TCB Info fails.""" + # Modify the sample to have an expired date + expired_json = SAMPLE_TCB_INFO_JSON.replace( + '"nextUpdate": "2026-01-16T06:24:56Z"', + '"nextUpdate": "2020-01-16T06:24:56Z"' + ) + tcb_info = parse_tcb_info_response(expired_json.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=expired_json.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + with pytest.raises(CollateralError, match="TCB Info has expired"): + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=18) + + def test_tcb_evaluation_data_number_threshold_passes(self): + """Test collateral passes when tcbEvaluationDataNumber meets threshold.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = 
parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + # Mock datetime to a time before sample data's nextUpdate (2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + # Sample data has tcbEvaluationDataNumber=18, threshold of 18 should pass + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=18) + + # Lower threshold should also pass + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=10) + + def test_tcb_info_evaluation_data_number_below_threshold(self): + """Test failure when TCB Info tcbEvaluationDataNumber is below threshold.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + # Mock datetime to a time before sample data's nextUpdate (2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + # Sample data has tcbEvaluationDataNumber=18, threshold of 19 should fail + with pytest.raises(CollateralError, match="TCB Info tcbEvaluationDataNumber .* is below"): + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=19) + + def test_qe_identity_evaluation_data_number_below_threshold(self): + """Test failure when QE Identity tcbEvaluationDataNumber is below threshold.""" + # Create TCB Info with high tcbEvaluationDataNumber + high_eval_tcb_json = SAMPLE_TCB_INFO_JSON.replace( + '"tcbEvaluationDataNumber": 18', + '"tcbEvaluationDataNumber": 25' + ) + tcb_info = parse_tcb_info_response(high_eval_tcb_json.encode()) + # QE Identity still has tcbEvaluationDataNumber=18 + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=high_eval_tcb_json.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + # Mock datetime to a time before sample data's nextUpdate (2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + # TCB Info has 25, QE Identity has 18, threshold of 20 should fail on QE Identity + with pytest.raises(CollateralError, match="QE Identity tcbEvaluationDataNumber .* is below"): + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=20) + + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + +class TestHelperFunctions: + """Test helper functions.""" + + def test_parse_datetime(self): + """Test datetime parsing.""" + result = _parse_datetime("2025-12-17T06:24:56Z") + assert result.year == 2025 + assert result.month == 12 + assert result.day == 17 + assert result.tzinfo is not None + + 
def test_parse_hex_bytes(self): + """Test hex string parsing.""" + result = _parse_hex_bytes("aabbccdd") + assert result == bytes([0xaa, 0xbb, 0xcc, 0xdd]) + + +class TestTcbStatus: + """Test TcbStatus enum.""" + + def test_status_values(self): + """Test all status values can be created.""" + assert TcbStatus.UP_TO_DATE == "UpToDate" + assert TcbStatus.OUT_OF_DATE == "OutOfDate" + assert TcbStatus.REVOKED == "Revoked" + assert TcbStatus.SW_HARDENING_NEEDED == "SWHardeningNeeded" + + +# ============================================================================= +# Caching Tests +# ============================================================================= + +class TestCacheHelpers: + """Test cache helper functions.""" + + def test_tcb_info_cache_path(self): + """Test TCB Info cache path generation.""" + path = _get_tcb_info_cache_path("00A06D080000") + assert "tdx_tcb_info_00a06d080000.json" in path + # Should be lowercase + path2 = _get_tcb_info_cache_path("00a06d080000") + assert path == path2 + + def test_qe_identity_cache_path(self): + """Test QE Identity cache path generation.""" + path = _get_qe_identity_cache_path() + assert "tdx_qe_identity.json" in path + + def test_is_tcb_info_fresh(self): + """Test TCB Info freshness check.""" + # Fresh: next_update is in the future + fresh_tcb_info = TdxTcbInfo( + tcb_info=TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc) - timedelta(days=1), + next_update=datetime.now(timezone.utc) + timedelta(days=29), + fmspc="00A06D080000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=1, + tdx_module=None, + tdx_module_identities=[], + tcb_levels=[], + ), + signature="", + ) + assert _is_tcb_info_fresh(fresh_tcb_info) is True + + # Stale: next_update is in the past + stale_tcb_info = TdxTcbInfo( + tcb_info=TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc) - timedelta(days=31), + next_update=datetime.now(timezone.utc) - timedelta(days=1), + fmspc="00A06D080000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=1, + tdx_module=None, + tdx_module_identities=[], + tcb_levels=[], + ), + signature="", + ) + assert _is_tcb_info_fresh(stale_tcb_info) is False + + def test_is_qe_identity_fresh(self): + """Test QE Identity freshness check.""" + # Fresh + fresh_qe = QeIdentity( + enclave_identity=EnclaveIdentity( + id="TD_QE", + version=2, + issue_date=datetime.now(timezone.utc) - timedelta(days=1), + next_update=datetime.now(timezone.utc) + timedelta(days=29), + tcb_evaluation_data_number=1, + miscselect=b"\x00" * 4, + miscselect_mask=b"\xff" * 4, + attributes=b"\x00" * 16, + attributes_mask=b"\xff" * 16, + mrsigner=b"\xaa" * 32, + isv_prod_id=1, + tcb_levels=[], + ), + signature="", + ) + assert _is_qe_identity_fresh(fresh_qe) is True + + # Stale + stale_qe = QeIdentity( + enclave_identity=EnclaveIdentity( + id="TD_QE", + version=2, + issue_date=datetime.now(timezone.utc) - timedelta(days=31), + next_update=datetime.now(timezone.utc) - timedelta(days=1), + tcb_evaluation_data_number=1, + miscselect=b"\x00" * 4, + miscselect_mask=b"\xff" * 4, + attributes=b"\x00" * 16, + attributes_mask=b"\xff" * 16, + mrsigner=b"\xaa" * 32, + isv_prod_id=1, + tcb_levels=[], + ), + signature="", + ) + assert _is_qe_identity_fresh(stale_qe) is False + + +class TestIssuerChainParsing: + """Test issuer chain header parsing.""" + + def test_parse_issuer_chain_missing_certs(self): + """Test that parsing fails with too few certificates.""" + # Single certificate is not enough (need signing + root at minimum) + 
single_cert_pem = """-----BEGIN CERTIFICATE----- +MIICjzCCAjSgAwIBAgIUImUM1lqdNInzg7SVUr9QGzknBqwwCgYIKoZIzj0EAwIw +aDEaMBgGA1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENv +cnBvcmF0aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJ +BgNVBAYTAlVTMB4XDTE4MDUyMTEwNDUxMFoXDTQ5MTIzMTIzNTk1OVowaDEaMBgG +A1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENvcnBvcmF0 +aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJBgNVBAYT +AlVTMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEC6nEwMDIYZOj/iPWsCzaEKi7 +1OiOSLRFhWGjbnBVJfVnkY4u3IjkDYYL0MxO4mqsyYjlBalTVYxFP2sJBK5zlKOB +uzCBuDAfBgNVHSMEGDAWgBQiZQzWWp00ifODtJVSv1AbOScGrDBSBgNVHR8ESzBJ +MEegRaBDhkFodHRwczovL2NlcnRpZmljYXRlcy50cnVzdGVkc2VydmljZXMuaW50 +ZWwuY29tL0ludGVsU0dYUm9vdENBLmRlcjAdBgNVHQ4EFgQUImUM1lqdNInzg7SV +Ur9QGzknBqwwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQEwCgYI +KoZIzj0EAwIDSQAwRgIhAOW/5QkR+S9CiSDcNoowLuPRLsWGf/Yi7GSX94BgwTwg +AiEA4J0lrHoMs+Xo5o/sX6O9QWxHRAvZUGOdRQ7cvqRXaqI= +-----END CERTIFICATE-----""" + with pytest.raises(CollateralError, match="at least 2 certificates"): + _parse_issuer_chain_header(single_cert_pem) + + def test_parse_issuer_chain_invalid_pem(self): + """Test that parsing fails with invalid PEM data.""" + with pytest.raises(CollateralError, match="Failed to parse"): + _parse_issuer_chain_header("not-valid-pem-data") + + +class TestCollateralSignatureVerification: + """Test collateral signature verification.""" + + def test_verify_signature_missing_key(self): + """Test that verification fails when JSON key is missing.""" + json_data = b'{"other": "data"}' + with pytest.raises(CollateralError, match="does not contain"): + _verify_collateral_signature( + json_bytes=json_data, + json_key="tcbInfo", + signature_hex="00" * 64, + signing_cert=MagicMock(), + data_name="Test", + ) + + def test_verify_signature_invalid_hex(self): + """Test that verification fails with invalid signature hex.""" + json_data = b'{"tcbInfo": {}}' + with pytest.raises(CollateralError, match="not valid hex"): + _verify_collateral_signature( + json_bytes=json_data, + json_key="tcbInfo", + signature_hex="not-hex", + signing_cert=MagicMock(), + data_name="Test", + ) + + def test_verify_signature_wrong_length(self): + """Test that verification fails with wrong signature length.""" + json_data = b'{"tcbInfo": {}}' + with pytest.raises(CollateralError, match="expected 64"): + _verify_collateral_signature( + json_bytes=json_data, + json_key="tcbInfo", + signature_hex="00" * 32, # 32 bytes instead of 64 + signing_cert=MagicMock(), + data_name="Test", + ) + + +class TestFetchTcbInfoWithCache: + """Test fetch_tcb_info caching behavior.""" + + def test_cache_miss_fetches_from_network(self, tmp_path): + """Test that cache miss triggers network fetch.""" + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + with patch('tinfoil.attestation.collateral_tdx.verify_tcb_info_signature'): + mock_response = MagicMock() + mock_response.content = SAMPLE_TCB_INFO_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"TCB-Info-Issuer-Chain": "dummy"} + mock_get.return_value = mock_response + + # Also mock _parse_issuer_chain_header since we have dummy header + with patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse: + mock_parse.return_value = [] + + tcb_info, raw, _chain = fetch_tcb_info("00A06D080000") + + mock_get.assert_called_once() + assert tcb_info.tcb_info.id == "TDX" + + def 
test_cache_hit_skips_network(self, tmp_path): + """Test that fresh cache hit skips network fetch.""" + import base64 + import json as json_mod + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + # Write fresh cache in the new JSON format with issuer chain + cache_path = tmp_path / "tdx_tcb_info_00a06d080000.json" + cache_data = { + "body": base64.b64encode(SAMPLE_TCB_INFO_RESPONSE).decode("ascii"), + "issuer_chain_pem": "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + } + cache_path.write_text(json_mod.dumps(cache_data)) + + # Mock datetime so sample data appears fresh (nextUpdate is 2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + # Mock signature verification on cache hit + with patch('tinfoil.attestation.collateral_tdx.verify_tcb_info_signature'): + with patch('tinfoil.attestation.collateral_tdx.parse_pem_chain') as mock_pem: + mock_pem.return_value = [] + tcb_info, raw, _chain = fetch_tcb_info("00A06D080000") + + # Should not call network + mock_get.assert_not_called() + assert tcb_info.tcb_info.id == "TDX" + + def test_stale_cache_fetches_fresh(self, tmp_path): + """Test that stale cache triggers fresh fetch.""" + import base64 + import json as json_mod + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + # Write stale cache (expired next_update) in new JSON format + cache_path = tmp_path / "tdx_tcb_info_00a06d080000.json" + cache_data = { + "body": base64.b64encode(SAMPLE_STALE_TCB_INFO_RESPONSE).decode("ascii"), + "issuer_chain_pem": "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + } + cache_path.write_text(json_mod.dumps(cache_data)) + + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + with patch('tinfoil.attestation.collateral_tdx.verify_tcb_info_signature'): + mock_response = MagicMock() + mock_response.content = SAMPLE_TCB_INFO_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"TCB-Info-Issuer-Chain": "dummy"} + mock_get.return_value = mock_response + + with patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse: + mock_parse.return_value = [] + + tcb_info, raw, _chain = fetch_tcb_info("00A06D080000") + + # Should call network because cache is stale + mock_get.assert_called_once() + + +class TestFetchQeIdentityWithCache: + """Test fetch_qe_identity caching behavior.""" + + def test_cache_miss_fetches_from_network(self, tmp_path): + """Test that cache miss triggers network fetch.""" + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + with patch('tinfoil.attestation.collateral_tdx.verify_qe_identity_signature'): + mock_response = MagicMock() + mock_response.content = SAMPLE_QE_IDENTITY_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"SGX-Enclave-Identity-Issuer-Chain": "dummy"} + mock_get.return_value = mock_response + + with patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse: + mock_parse.return_value = [] + + qe_identity, raw, _chain = fetch_qe_identity() + + mock_get.assert_called_once() + assert qe_identity.enclave_identity.id == "TD_QE" + + def 
test_cache_hit_skips_network(self, tmp_path): + """Test that fresh cache hit skips network fetch.""" + import base64 + import json as json_mod + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + # Write fresh cache in the new JSON format with issuer chain + cache_path = tmp_path / "tdx_qe_identity.json" + cache_data = { + "body": base64.b64encode(SAMPLE_QE_IDENTITY_RESPONSE).decode("ascii"), + "issuer_chain_pem": "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + } + cache_path.write_text(json_mod.dumps(cache_data)) + + # Mock datetime so sample data appears fresh (nextUpdate is 2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + # Mock signature verification on cache hit + with patch('tinfoil.attestation.collateral_tdx.verify_qe_identity_signature'): + with patch('tinfoil.attestation.collateral_tdx.parse_pem_chain') as mock_pem: + mock_pem.return_value = [] + qe_identity, raw, _chain = fetch_qe_identity() + + # Should not call network + mock_get.assert_not_called() + assert qe_identity.enclave_identity.id == "TD_QE" + + +# ============================================================================= +# CRL Tests +# ============================================================================= + +from tinfoil.attestation.collateral_tdx import ( + PckCrl, + _get_crl_cache_path, + _is_crl_fresh, + _determine_pck_ca_type, + fetch_pck_crl, + validate_certificate_revocation, +) +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec + + +class TestCrlCacheHelpers: + """Test CRL cache helper functions.""" + + def test_crl_cache_path_platform(self): + """Test CRL cache path generation for platform CA.""" + path = _get_crl_cache_path("platform") + assert "tdx_pck_crl_platform.json" in path + + def test_crl_cache_path_processor(self): + """Test CRL cache path generation for processor CA.""" + path = _get_crl_cache_path("processor") + assert "tdx_pck_crl_processor.json" in path + + def test_crl_cache_path_lowercase(self): + """Test CRL cache path is lowercase.""" + path1 = _get_crl_cache_path("Platform") + path2 = _get_crl_cache_path("platform") + assert path1 == path2 + + def test_is_crl_fresh_valid(self): + """Test CRL freshness check with valid (not expired) CRL.""" + # Create a mock CRL with next_update in the future + mock_crl = MagicMock() + mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=30) + + assert _is_crl_fresh(mock_crl) is True + + def test_is_crl_fresh_expired(self): + """Test CRL freshness check with expired CRL.""" + mock_crl = MagicMock() + mock_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) + + assert _is_crl_fresh(mock_crl) is False + + def test_is_crl_fresh_no_next_update(self): + """Test CRL freshness check with no next_update.""" + mock_crl = MagicMock() + mock_crl.next_update_utc = None + + assert _is_crl_fresh(mock_crl) is False + + +class TestDeterminePckCaType: + """Test PCK CA type determination from certificate issuer.""" + + def _create_mock_cert_with_issuer(self, cn: str) -> MagicMock: + """Create a mock certificate with given issuer CN.""" + mock_attr = MagicMock() + mock_attr.oid = 
NameOID.COMMON_NAME + mock_attr.value = cn + + mock_issuer = MagicMock() + mock_issuer.__iter__ = MagicMock(return_value=iter([mock_attr])) + + mock_cert = MagicMock() + mock_cert.issuer = mock_issuer + + return mock_cert + + def test_platform_ca(self): + """Test detecting Platform CA from issuer CN.""" + mock_cert = self._create_mock_cert_with_issuer("Intel SGX PCK Platform CA") + result = _determine_pck_ca_type(mock_cert) + assert result == "platform" + + def test_processor_ca(self): + """Test detecting Processor CA from issuer CN.""" + mock_cert = self._create_mock_cert_with_issuer("Intel SGX PCK Processor CA") + result = _determine_pck_ca_type(mock_cert) + assert result == "processor" + + def test_unknown_ca_raises_error(self): + """Test that unknown CA type raises error.""" + mock_cert = self._create_mock_cert_with_issuer("Unknown CA") + with pytest.raises(CollateralError, match="Could not determine PCK CA type"): + _determine_pck_ca_type(mock_cert) + + +class TestFetchPckCrl: + """Test fetch_pck_crl function.""" + + def test_invalid_ca_type_raises_error(self): + """Test that invalid CA type raises error.""" + with pytest.raises(CollateralError, match="Invalid CA type"): + fetch_pck_crl("invalid") + + def test_cache_hit_skips_network(self, tmp_path): + """Test that fresh cache hit skips network fetch.""" + import base64 + import json as json_mod + # Create a mock CRL + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives import hashes + from cryptography import x509 + from cryptography.x509 import CertificateRevocationListBuilder + import datetime as dt_module + + # Generate a key for signing + private_key = ec.generate_private_key(ec.SECP256R1()) + + # Build a minimal CRL + builder = CertificateRevocationListBuilder() + builder = builder.issuer_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "Test CA"), + ])) + builder = builder.last_update(dt_module.datetime.now(dt_module.timezone.utc)) + builder = builder.next_update(dt_module.datetime.now(dt_module.timezone.utc) + dt_module.timedelta(days=30)) + + crl = builder.sign(private_key, hashes.SHA256()) + # Use PEM encoding - fetch_pck_crl expects PEM format + crl_pem = crl.public_bytes(serialization.Encoding.PEM) + + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + # Write fresh CRL cache in new JSON format (body is PEM bytes) + cache_path = tmp_path / "tdx_pck_crl_platform.json" + cache_data = { + "body": base64.b64encode(crl_pem).decode("ascii"), + "issuer_chain_pem": "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + } + cache_path.write_text(json_mod.dumps(cache_data)) + + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + # Mock signature verification on cache hit + with patch('tinfoil.attestation.collateral_tdx._verify_crl_signature'): + with patch('tinfoil.attestation.collateral_tdx.parse_pem_chain') as mock_pem: + mock_pem.return_value = [] + result = fetch_pck_crl("platform") + + # Should not call network + mock_get.assert_not_called() + assert result.ca_type == "platform" + + def test_cache_miss_fetches_from_network(self, tmp_path): + """Test that cache miss triggers network fetch.""" + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives import hashes, serialization + from cryptography import x509 + from cryptography.x509 import CertificateRevocationListBuilder + import datetime as dt_module + + # Generate keys + private_key = 
ec.generate_private_key(ec.SECP256R1()) + + # Build a minimal CRL + builder = CertificateRevocationListBuilder() + builder = builder.issuer_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "Test CA"), + ])) + builder = builder.last_update(dt_module.datetime.now(dt_module.timezone.utc)) + builder = builder.next_update(dt_module.datetime.now(dt_module.timezone.utc) + dt_module.timedelta(days=30)) + + crl = builder.sign(private_key, hashes.SHA256()) + # Use PEM encoding - fetch_pck_crl expects PEM format from Intel PCS + crl_pem = crl.public_bytes(serialization.Encoding.PEM) + + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + with patch('tinfoil.attestation.collateral_tdx._verify_crl_signature'): + with patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse: + mock_parse.return_value = [] + + mock_response = MagicMock() + mock_response.content = crl_pem + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"SGX-PCK-CRL-Issuer-Chain": "dummy"} + mock_get.return_value = mock_response + + result = fetch_pck_crl("platform") + + mock_get.assert_called_once() + assert result.ca_type == "platform" + + +class TestValidateCertificateRevocation: + """Test certificate revocation validation.""" + + def _create_mock_crl(self, revoked_serials: list[int] = None) -> MagicMock: + """Create a mock CRL with optional revoked serials.""" + mock_crl = MagicMock() + mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=30) + + def get_revoked(serial): + if revoked_serials and serial in revoked_serials: + revoked = MagicMock() + revoked.revocation_date_utc = datetime.now(timezone.utc) - timedelta(days=1) + return revoked + return None + + mock_crl.get_revoked_certificate_by_serial_number = get_revoked + return mock_crl + + def _create_mock_cert(self, serial: int) -> MagicMock: + """Create a mock certificate with given serial number.""" + mock_cert = MagicMock() + mock_cert.serial_number = serial + return mock_cert + + def test_certificate_not_revoked(self): + """Test that non-revoked certificate passes validation.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + ) + + mock_cert = self._create_mock_cert(serial=12345) + + # Should not raise + validate_certificate_revocation(collateral, mock_cert) + + def test_certificate_revoked(self): + """Test that revoked certificate fails validation.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_crl = self._create_mock_crl(revoked_serials=[12345]) + pck_crl = PckCrl( + crl=mock_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + ) + + mock_cert = 
self._create_mock_cert(serial=12345) + + with pytest.raises(CollateralError, match="has been revoked"): + validate_certificate_revocation(collateral, mock_cert) + + def test_no_crl_raises_error(self): + """Test that missing CRL raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=None, # No CRL + ) + + mock_cert = self._create_mock_cert(serial=12345) + + with pytest.raises(CollateralError, match="CRL not available"): + validate_certificate_revocation(collateral, mock_cert) + + def test_expired_crl_raises_error(self): + """Test that expired CRL raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_crl = MagicMock() + mock_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) # Expired + + pck_crl = PckCrl( + crl=mock_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) - timedelta(days=1), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + ) + + mock_cert = self._create_mock_cert(serial=12345) + + with pytest.raises(CollateralError, match="CRL has expired"): + validate_certificate_revocation(collateral, mock_cert) + + def test_intermediate_cert_not_revoked(self): + """Test that non-revoked intermediate CA certificate passes validation.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_pck_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_pck_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + mock_root_crl = self._create_mock_crl(revoked_serials=[]) + from tinfoil.attestation.collateral_tdx import RootCrl + root_crl = RootCrl( + crl=mock_root_crl, + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + root_crl=root_crl, + ) + + mock_pck_cert = self._create_mock_cert(serial=12345) + mock_intermediate_cert = self._create_mock_cert(serial=67890) + + # Should not raise + validate_certificate_revocation(collateral, mock_pck_cert, mock_intermediate_cert) + + def test_intermediate_cert_revoked(self): + """Test that revoked intermediate CA certificate fails validation.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_pck_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_pck_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + # Root CRL has the intermediate CA's serial as revoked + mock_root_crl = self._create_mock_crl(revoked_serials=[67890]) + from tinfoil.attestation.collateral_tdx import RootCrl + root_crl = RootCrl( + crl=mock_root_crl, + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + 
tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + root_crl=root_crl, + ) + + mock_pck_cert = self._create_mock_cert(serial=12345) + mock_intermediate_cert = self._create_mock_cert(serial=67890) + + with pytest.raises(CollateralError, match="Intermediate CA certificate has been revoked"): + validate_certificate_revocation(collateral, mock_pck_cert, mock_intermediate_cert) + + def test_no_root_crl_with_intermediate_raises_error(self): + """Test that missing root CRL raises error when checking intermediate.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_pck_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_pck_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + root_crl=None, # No root CRL + ) + + mock_pck_cert = self._create_mock_cert(serial=12345) + mock_intermediate_cert = self._create_mock_cert(serial=67890) + + with pytest.raises(CollateralError, match="Root CRL not available"): + validate_certificate_revocation(collateral, mock_pck_cert, mock_intermediate_cert) + + def test_expired_root_crl_raises_error(self): + """Test that expired root CRL raises error when checking intermediate.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_pck_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_pck_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + mock_root_crl = MagicMock() + mock_root_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) # Expired + from tinfoil.attestation.collateral_tdx import RootCrl + root_crl = RootCrl( + crl=mock_root_crl, + next_update=datetime.now(timezone.utc) - timedelta(days=1), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + root_crl=root_crl, + ) + + mock_pck_cert = self._create_mock_cert(serial=12345) + mock_intermediate_cert = self._create_mock_cert(serial=67890) + + with pytest.raises(CollateralError, match="Root CA CRL has expired"): + validate_certificate_revocation(collateral, mock_pck_cert, mock_intermediate_cert) + + +# ============================================================================= +# TCB Evaluation Data Numbers Tests +# ============================================================================= + +from tinfoil.attestation.collateral_tdx import ( + TcbEvalNumber, + TcbEvaluationDataNumbers, + _parse_tcb_eval_numbers_response, + fetch_tcb_evaluation_data_numbers, + calculate_min_tcb_evaluation_data_number, +) + + +SAMPLE_TCB_EVAL_NUMBERS_RESPONSE = b'''{ + "tcbEvaluationDataNumbers": { + "id": "TDX", + "version": 1, + "issueDate": "2026-01-15T19:17:02Z", + "nextUpdate": "2026-02-14T19:17:02Z", + "tcbEvalNumbers": [ + {"tcbEvaluationDataNumber": 20, "tcbRecoveryEventDate": "2025-08-12T00:00:00Z", "tcbDate": "2025-08-13T00:00:00Z"}, + {"tcbEvaluationDataNumber": 19, "tcbRecoveryEventDate": 
"2025-05-13T00:00:00Z", "tcbDate": "2025-05-14T00:00:00Z"}, + {"tcbEvaluationDataNumber": 18, "tcbRecoveryEventDate": "2024-11-12T00:00:00Z", "tcbDate": "2024-11-13T00:00:00Z"}, + {"tcbEvaluationDataNumber": 17, "tcbRecoveryEventDate": "2024-03-12T00:00:00Z", "tcbDate": "2024-03-13T00:00:00Z"}, + {"tcbEvaluationDataNumber": 16, "tcbRecoveryEventDate": "2023-08-08T00:00:00Z", "tcbDate": "2023-08-09T00:00:00Z"} + ] + }, + "signature": "dummy" +}''' + + +class TestParseTcbEvalNumbersResponse: + """Test parsing of tcbevaluationdatanumbers response.""" + + def test_parse_valid_response(self): + """Test parsing a valid response.""" + result = _parse_tcb_eval_numbers_response(SAMPLE_TCB_EVAL_NUMBERS_RESPONSE) + + assert result.id == "TDX" + assert result.version == 1 + assert len(result.tcb_eval_numbers) == 5 + assert result.tcb_eval_numbers[0].tcb_evaluation_data_number == 20 + assert result.tcb_eval_numbers[4].tcb_evaluation_data_number == 16 + + def test_parse_dates(self): + """Test that dates are parsed correctly.""" + result = _parse_tcb_eval_numbers_response(SAMPLE_TCB_EVAL_NUMBERS_RESPONSE) + + # Check first entry + assert result.tcb_eval_numbers[0].tcb_recovery_event_date.year == 2025 + assert result.tcb_eval_numbers[0].tcb_recovery_event_date.month == 8 + assert result.tcb_eval_numbers[0].tcb_recovery_event_date.day == 12 + + def test_parse_invalid_json(self): + """Test that invalid JSON raises error.""" + with pytest.raises(CollateralError, match="Failed to parse"): + _parse_tcb_eval_numbers_response(b"not json") + + +class TestFetchTcbEvaluationDataNumbers: + """Test fetching TCB evaluation data numbers.""" + + def test_fetch_success(self): + """Test successful fetch with signature verification mocked out.""" + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get, \ + patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse, \ + patch('tinfoil.attestation.collateral_tdx.verify_tcb_eval_numbers_signature'): + mock_response = MagicMock() + mock_response.content = SAMPLE_TCB_EVAL_NUMBERS_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"TCB-Evaluation-Data-Numbers-Issuer-Chain": "dummy-chain"} + mock_get.return_value = mock_response + mock_parse.return_value = [] + + result = fetch_tcb_evaluation_data_numbers() + + mock_get.assert_called_once() + assert result.id == "TDX" + + def test_fetch_missing_issuer_chain_header(self): + """Test that missing issuer chain header raises error.""" + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + mock_response = MagicMock() + mock_response.content = SAMPLE_TCB_EVAL_NUMBERS_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {} + mock_get.return_value = mock_response + + with pytest.raises(CollateralError, match="missing TCB-Evaluation-Data-Numbers-Issuer-Chain"): + fetch_tcb_evaluation_data_numbers() + + def test_fetch_failure(self): + """Test fetch failure raises error.""" + import requests as req + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + mock_get.side_effect = req.RequestException("Network error") + + with pytest.raises(CollateralError, match="HTTP GET"): + fetch_tcb_evaluation_data_numbers() + + +class TestCalculateMinTcbEvaluationDataNumber: + """Test calculate_min_tcb_evaluation_data_number function.""" + + def test_calculate_with_recent_cutoff(self): + """Test calculation finds correct minimum with recent cutoff.""" + with 
patch('tinfoil.attestation.collateral_tdx.fetch_tcb_evaluation_data_numbers') as mock_fetch: + mock_fetch.return_value = _parse_tcb_eval_numbers_response(SAMPLE_TCB_EVAL_NUMBERS_RESPONSE) + + # With 365-day cutoff from 2026-01-15, numbers from 2025-01-15 or later are acceptable + # Number 19 (2025-05-13) should be the minimum acceptable + # Number 18 (2024-11-12) is older than 1 year + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_now = datetime(2026, 1, 15, tzinfo=timezone.utc) + mock_dt.now.return_value = mock_now + mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw) + + result = calculate_min_tcb_evaluation_data_number(max_age_days=365) + + # Should return 19 since 18 is older than 1 year + assert result == 19 + + def test_calculate_with_shorter_cutoff(self): + """Test calculation with shorter max age.""" + with patch('tinfoil.attestation.collateral_tdx.fetch_tcb_evaluation_data_numbers') as mock_fetch: + mock_fetch.return_value = _parse_tcb_eval_numbers_response(SAMPLE_TCB_EVAL_NUMBERS_RESPONSE) + + # With 180-day cutoff from 2026-01-15, cutoff is ~2025-07-19 + # Number 20 (2025-08-12) should be acceptable + # Number 19 (2025-05-13) is older than 180 days + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_now = datetime(2026, 1, 15, tzinfo=timezone.utc) + mock_dt.now.return_value = mock_now + mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw) + + result = calculate_min_tcb_evaluation_data_number(max_age_days=180) + + assert result == 20 + + def test_calculate_empty_numbers_raises(self): + """Test that empty numbers list raises error.""" + with patch('tinfoil.attestation.collateral_tdx.fetch_tcb_evaluation_data_numbers') as mock_fetch: + mock_fetch.return_value = TcbEvaluationDataNumbers( + id="TDX", + version=1, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + tcb_eval_numbers=[], + signature="dummy", + ) + + with pytest.raises(CollateralError, match="No TCB evaluation data numbers found"): + calculate_min_tcb_evaluation_data_number() + + def test_calculate_all_too_old_raises(self): + """Test that all numbers too old raises error.""" + # Create response with only old numbers + old_numbers_response = b'''{ + "tcbEvaluationDataNumbers": { + "id": "TDX", + "version": 1, + "issueDate": "2026-01-15T19:17:02Z", + "nextUpdate": "2026-02-14T19:17:02Z", + "tcbEvalNumbers": [ + {"tcbEvaluationDataNumber": 10, "tcbRecoveryEventDate": "2020-01-01T00:00:00Z", "tcbDate": "2020-01-02T00:00:00Z"} + ] + }, + "signature": "dummy" + }''' + + with patch('tinfoil.attestation.collateral_tdx.fetch_tcb_evaluation_data_numbers') as mock_fetch: + mock_fetch.return_value = _parse_tcb_eval_numbers_response(old_numbers_response) + + with pytest.raises(CollateralError, match="older than 365 days"): + calculate_min_tcb_evaluation_data_number() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From ba2c999b038990fd52a11c2fa7614c2302d3d042 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:14:47 -0500 Subject: [PATCH 05/19] feat: add PCK extension parsing with pyasn1 (pck_extensions.py) Parse Intel PCK certificate X.509v3 extensions using pyasn1 for proper ASN.1 decoding. Extracts FMSPC, PCE ID, TCB components, and CPUSVN with strict validation and error handling. 
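
A minimal usage sketch (the certificate path is hypothetical; the API is the
one added below):

    from cryptography import x509
    from tinfoil.attestation.pck_extensions import (
        PckExtensionError,
        extract_pck_extensions,
    )

    with open("pck_leaf.pem", "rb") as f:  # hypothetical PCK leaf certificate
        cert = x509.load_pem_x509_certificate(f.read())

    try:
        ext = extract_pck_extensions(cert)
        print(ext.fmspc, ext.pceid, ext.tcb.pce_svn)
    except PckExtensionError as e:
        print(f"not a valid PCK certificate: {e}")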
---
 src/tinfoil/attestation/pck_extensions.py | 273 ++++++++++++++++++++++
 tests/test_pck_extensions.py              | 144 ++++++++++++
 2 files changed, 417 insertions(+)
 create mode 100644 src/tinfoil/attestation/pck_extensions.py
 create mode 100644 tests/test_pck_extensions.py

diff --git a/src/tinfoil/attestation/pck_extensions.py b/src/tinfoil/attestation/pck_extensions.py
new file mode 100644
index 0000000..acf86c1
--- /dev/null
+++ b/src/tinfoil/attestation/pck_extensions.py
@@ -0,0 +1,273 @@
+"""
+PCK Certificate Extension Parsing for Intel TDX attestation.
+
+This module extracts Intel SGX-specific extensions from PCK (Provisioning
+Certification Key) certificates. These extensions contain critical information
+for TCB (Trusted Computing Base) validation:
+
+- FMSPC: Family-Model-Stepping-Platform-CustomSKU identifier (6 bytes)
+- PCEID: Provisioning Certification Enclave (PCE) ID (2 bytes)
+- TCB: TCB components including PCE SVN and CPU SVN
+
+Intel SGX Extension OID hierarchy:
+    1.2.840.113741.1.13.1 - SGX Extension (parent)
+    1.2.840.113741.1.13.1.1 - PPID
+    1.2.840.113741.1.13.1.2 - TCB
+    1.2.840.113741.1.13.1.2.1-16 - TCB Components
+    1.2.840.113741.1.13.1.2.17 - PCE SVN
+    1.2.840.113741.1.13.1.2.18 - CPU SVN
+    1.2.840.113741.1.13.1.3 - PCEID
+    1.2.840.113741.1.13.1.4 - FMSPC
+"""
+
+from contextlib import contextmanager
+from dataclasses import dataclass
+
+from cryptography import x509
+from pyasn1.codec.der import decoder as der_decoder
+
+# Intel SGX OID base
+_INTEL_SGX_OID_BASE = "1.2.840.113741.1.13.1"
+
+# Intel SGX Extension OIDs
+OID_SGX_EXTENSION = x509.ObjectIdentifier(_INTEL_SGX_OID_BASE)
+OID_PPID = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.1")
+OID_TCB = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.2")
+OID_PCEID = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.3")
+OID_FMSPC = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.4")
+
+# TCB sub-OIDs
+OID_PCE_SVN = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.2.17")
+OID_CPU_SVN = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.2.18")
+
+# TCB component OID → index mapping (1-indexed OID to 0-indexed array position)
+_TCB_COMPONENT_OID_INDEX = {
+    f"{_INTEL_SGX_OID_BASE}.2.{i}": i - 1 for i in range(1, 17)
+}
+
+# Size constants (match go-tdx-guest/pcs/pcs.go)
+PCK_CERT_EXTENSION_COUNT = 6
+SGX_EXTENSION_MIN_SIZE = 4
+TCB_EXTENSION_SIZE = 18  # 16 components + PCE SVN + CPU SVN
+PPID_SIZE = 16
+CPU_SVN_SIZE = 16
+FMSPC_SIZE = 6
+PCEID_SIZE = 2
+TCB_COMPONENTS_COUNT = 16
+
+
+class PckExtensionError(Exception):
+    """Raised when PCK certificate extension parsing fails."""
+    pass
+
+
+@contextmanager
+def _asn1_errors(label: str):
+    """Wrap unexpected ASN.1 exceptions as PckExtensionError."""
+    try:
+        yield
+    except PckExtensionError:
+        raise
+    except Exception as e:
+        raise PckExtensionError(f"Unexpected ASN.1 structure in {label}: {e}") from e
+
+
+@dataclass
+class PckCertTCB:
+    """
+    TCB (Trusted Computing Base) information from PCK certificate.
+
+    Attributes:
+        pce_svn: PCE Security Version Number
+        cpu_svn: CPU SVN as raw bytes (16 bytes)
+        tcb_components: Individual TCB component SVNs (16 values)
+    """
+    pce_svn: int
+    cpu_svn: bytes  # 16 bytes
+    tcb_components: list[int]  # 16 components
+
+    def __str__(self) -> str:
+        return (
+            f"PckCertTCB(pce_svn={self.pce_svn}, "
+            f"cpu_svn={self.cpu_svn.hex()}, "
+            f"components={self.tcb_components})"
+        )
+
+
+@dataclass
+class PckExtensions:
+    """
+    Intel SGX extensions extracted from a PCK certificate.
+
+    Attributes:
+        ppid: Platform Provisioning ID (PPID) as hex string (16 bytes)
+        tcb: TCB information
+        pceid: PCE ID as hex string (2 bytes)
+        fmspc: FMSPC (Family-Model-Stepping-Platform-CustomSKU) as hex string (6 bytes)
+    """
+    ppid: str  # 16 bytes hex
+    tcb: PckCertTCB
+    pceid: str  # 2 bytes hex
+    fmspc: str  # 6 bytes hex
+
+    def __str__(self) -> str:
+        return (
+            f"PckExtensions(fmspc={self.fmspc}, pceid={self.pceid}, "
+            f"ppid={self.ppid[:8]}...)"
+        )
+
+
+def extract_pck_extensions(cert: x509.Certificate) -> PckExtensions:
+    """
+    Extract Intel SGX extensions from a PCK certificate.
+
+    Args:
+        cert: PCK leaf certificate from the quote
+
+    Returns:
+        PckExtensions containing FMSPC, PCEID, PPID, and TCB info
+
+    Raises:
+        PckExtensionError: If required extensions are missing or malformed
+    """
+    if len(cert.extensions) != PCK_CERT_EXTENSION_COUNT:
+        raise PckExtensionError(
+            f"PCK certificate has {len(cert.extensions)} extensions, "
+            f"expected {PCK_CERT_EXTENSION_COUNT}"
+        )
+
+    sgx_ext = None
+    for ext in cert.extensions:
+        if ext.oid == OID_SGX_EXTENSION:
+            sgx_ext = ext
+            break
+
+    if sgx_ext is None:
+        raise PckExtensionError(
+            "PCK certificate does not contain Intel SGX extension "
+            f"(OID {OID_SGX_EXTENSION.dotted_string})"
+        )
+
+    return _parse_sgx_extension(sgx_ext.value.value)
+
+
+def _der_decode(data: bytes, label: str = "value"):
+    """Decode DER data using pyasn1, wrapping errors as PckExtensionError.
+
+    Checks for leftover bytes after decoding (matches Go's asn1.Unmarshal behavior).
+    """
+    try:
+        result, remainder = der_decoder.decode(data)
+    except Exception as e:
+        raise PckExtensionError(f"Failed to decode ASN.1 {label}: {e}") from e
+    if remainder:
+        raise PckExtensionError(
+            f"Unexpected leftover bytes after decoding {label}: {len(remainder)} bytes"
+        )
+    return result
+
+
+def _octet_hex(component, name: str, expected_size: int) -> str:
+    """Extract bytes from a decoded ASN.1 value, validate size, return hex."""
+    raw = bytes(component)
+    if len(raw) != expected_size:
+        raise PckExtensionError(
+            f"{name} has wrong size: expected {expected_size}, got {len(raw)}"
+        )
+    return raw.hex()
+
+
+def _parse_sgx_extension(raw_value: bytes) -> PckExtensions:
+    """Parse the SGX extension: a SEQUENCE OF SEQUENCE(OID, value)."""
+    outer_seq = _der_decode(raw_value, "SGX extension")
+
+    if len(outer_seq) < SGX_EXTENSION_MIN_SIZE:
+        raise PckExtensionError(
+            f"SGX extension has {len(outer_seq)} elements, "
+            f"expected at least {SGX_EXTENSION_MIN_SIZE}"
+        )
+
+    ppid = None
+    tcb = None
+    pceid = None
+    fmspc = None
+
+    with _asn1_errors("SGX extension"):
+        for item in outer_seq:
+            if len(item) < 2:
+                raise PckExtensionError(
+                    f"Malformed SGX extension item: expected at least 2 fields, got {len(item)}"
+                )
+            oid_str = str(item[0])
+
+            if oid_str == OID_PPID.dotted_string:
+                ppid = _octet_hex(item[1], "PPID", PPID_SIZE)
+            elif oid_str == OID_TCB.dotted_string:
+                tcb = _parse_tcb(item[1])
+            elif oid_str == OID_PCEID.dotted_string:
+                pceid = _octet_hex(item[1], "PCEID", PCEID_SIZE)
+            elif oid_str == OID_FMSPC.dotted_string:
+                fmspc = _octet_hex(item[1], "FMSPC", FMSPC_SIZE)
+
+    if fmspc is None:
+        raise PckExtensionError("FMSPC not found in PCK certificate")
+    if pceid is None:
+        raise PckExtensionError("PCEID not found in PCK certificate")
+    if ppid is None:
+        raise PckExtensionError("PPID not found in PCK certificate")
+    if tcb is None:
+        raise PckExtensionError("TCB not found in PCK certificate")
+
+    return PckExtensions(ppid=ppid, tcb=tcb, pceid=pceid, fmspc=fmspc)
+
+
+def 
_parse_tcb(value_component) -> PckCertTCB: + """ + Parse the TCB extension: a SEQUENCE of (OID, value) pairs for + 16 TCB components, PCE SVN, and CPU SVN. + + value_component is already a decoded pyasn1 Sequence from the parent decode. + """ + tcb_seq = value_component + + if len(tcb_seq) != TCB_EXTENSION_SIZE: + raise PckExtensionError( + f"TCB extension has {len(tcb_seq)} elements, " + f"expected {TCB_EXTENSION_SIZE}" + ) + + tcb_components = [0] * TCB_COMPONENTS_COUNT + pce_svn = 0 + cpu_svn = bytes(CPU_SVN_SIZE) + + with _asn1_errors("TCB extension"): + for item in tcb_seq: + if len(item) < 2: + raise PckExtensionError( + f"Malformed TCB extension item: expected at least 2 fields, got {len(item)}" + ) + oid_str = str(item[0]) + + if oid_str == OID_PCE_SVN.dotted_string: + val = int(item[1]) + if val < 0 or val > 0xFFFF: + raise PckExtensionError( + f"PCE SVN value {val} out of uint16 range" + ) + pce_svn = val + elif oid_str == OID_CPU_SVN.dotted_string: + raw = bytes(item[1]) + if len(raw) != CPU_SVN_SIZE: + raise PckExtensionError( + f"CPU SVN has wrong size: expected {CPU_SVN_SIZE}, got {len(raw)}" + ) + cpu_svn = raw + elif (idx := _TCB_COMPONENT_OID_INDEX.get(oid_str)) is not None: + val = int(item[1]) + if val < 0 or val > 0xFF: + raise PckExtensionError( + f"TCB component {idx + 1} value {val} out of byte range" + ) + tcb_components[idx] = val + + return PckCertTCB(pce_svn=pce_svn, cpu_svn=cpu_svn, tcb_components=tcb_components) diff --git a/tests/test_pck_extensions.py b/tests/test_pck_extensions.py new file mode 100644 index 0000000..2f05627 --- /dev/null +++ b/tests/test_pck_extensions.py @@ -0,0 +1,144 @@ +""" +Unit tests for PCK certificate extension parsing. +""" + +import pytest + +from tinfoil.attestation.pck_extensions import ( + PckExtensions, + PckCertTCB, + PckExtensionError, + OID_SGX_EXTENSION, + OID_FMSPC, + OID_PCEID, + OID_PPID, + OID_TCB, + FMSPC_SIZE, + PCEID_SIZE, + PPID_SIZE, + TCB_COMPONENTS_COUNT, +) + + +# ============================================================================= +# Test OID Constants +# ============================================================================= + +class TestOidConstants: + """Test Intel SGX OID constants.""" + + def test_sgx_extension_oid(self): + """Test SGX extension OID value.""" + assert OID_SGX_EXTENSION.dotted_string == "1.2.840.113741.1.13.1" + + def test_fmspc_oid(self): + """Test FMSPC OID value.""" + assert OID_FMSPC.dotted_string == "1.2.840.113741.1.13.1.4" + + def test_pceid_oid(self): + """Test PCEID OID value.""" + assert OID_PCEID.dotted_string == "1.2.840.113741.1.13.1.3" + + def test_ppid_oid(self): + """Test PPID OID value.""" + assert OID_PPID.dotted_string == "1.2.840.113741.1.13.1.1" + + def test_tcb_oid(self): + """Test TCB OID value.""" + assert OID_TCB.dotted_string == "1.2.840.113741.1.13.1.2" + + +# ============================================================================= +# Test Size Constants +# ============================================================================= + +class TestSizeConstants: + """Test size constants.""" + + def test_fmspc_size(self): + """Test FMSPC is 6 bytes.""" + assert FMSPC_SIZE == 6 + + def test_pceid_size(self): + """Test PCEID is 2 bytes.""" + assert PCEID_SIZE == 2 + + def test_ppid_size(self): + """Test PPID is 16 bytes.""" + assert PPID_SIZE == 16 + + def test_tcb_components_count(self): + """Test TCB has 16 components.""" + assert TCB_COMPONENTS_COUNT == 16 + + +# ============================================================================= +# 
Test Data Classes +# ============================================================================= + +class TestPckCertTCB: + """Test PckCertTCB dataclass.""" + + def test_create_tcb(self): + """Test creating PckCertTCB.""" + tcb = PckCertTCB( + pce_svn=13, + cpu_svn=bytes(16), + tcb_components=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ) + assert tcb.pce_svn == 13 + assert len(tcb.cpu_svn) == 16 + assert len(tcb.tcb_components) == 16 + + def test_tcb_str(self): + """Test string representation.""" + tcb = PckCertTCB( + pce_svn=13, + cpu_svn=bytes([0xAB] * 16), + tcb_components=[0] * 16, + ) + s = str(tcb) + assert "pce_svn=13" in s + assert "abab" in s.lower() + + +class TestPckExtensions: + """Test PckExtensions dataclass.""" + + def test_create_extensions(self): + """Test creating PckExtensions.""" + tcb = PckCertTCB( + pce_svn=13, + cpu_svn=bytes(16), + tcb_components=[0] * 16, + ) + ext = PckExtensions( + ppid="00112233445566778899aabbccddeeff", + tcb=tcb, + pceid="0000", + fmspc="00606a000000", + ) + assert ext.fmspc == "00606a000000" + assert ext.pceid == "0000" + assert len(ext.ppid) == 32 # 16 bytes hex + + def test_extensions_str(self): + """Test string representation.""" + tcb = PckCertTCB( + pce_svn=13, + cpu_svn=bytes(16), + tcb_components=[0] * 16, + ) + ext = PckExtensions( + ppid="00112233445566778899aabbccddeeff", + tcb=tcb, + pceid="0000", + fmspc="00606a000000", + ) + s = str(ext) + assert "fmspc=00606a000000" in s + assert "pceid=0000" in s + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From f43326eeb70056a3d0af566fb39ec11d2304aa74 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:14:53 -0500 Subject: [PATCH 06/19] feat: add certificate chain utilities (cert_utils.py, intel_root_ca.py) Implement PEM chain parsing via cryptography's load_pem_x509_certificates, Intel SGX Root CA chain verification with public key pinning, validity period checks, BasicConstraints CA=True enforcement, and chain signature verification. Embeds the Intel SGX Root CA certificate. --- src/tinfoil/attestation/cert_utils.py | 152 +++++++++++++++++++++++ src/tinfoil/attestation/intel_root_ca.py | 80 ++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 src/tinfoil/attestation/cert_utils.py create mode 100644 src/tinfoil/attestation/intel_root_ca.py diff --git a/src/tinfoil/attestation/cert_utils.py b/src/tinfoil/attestation/cert_utils.py new file mode 100644 index 0000000..911e673 --- /dev/null +++ b/src/tinfoil/attestation/cert_utils.py @@ -0,0 +1,152 @@ +""" +Shared certificate utilities for TDX attestation verification. + +This module provides common certificate parsing and chain verification +functions used by both verify_tdx.py and collateral_tdx.py. +""" + +from datetime import datetime, timezone +from typing import List + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization + +from .intel_root_ca import get_intel_root_ca + +class CertificateChainError(Exception): + """Raised when certificate chain verification fails.""" + pass + + +def parse_pem_chain(pem_data: bytes) -> List[x509.Certificate]: + """ + Parse concatenated PEM certificates. 
+
+    Handles:
+    - Concatenated PEM certificates
+    - Leading/trailing whitespace and null bytes
+    - Trailing null bytes between certificates (common in TDX quotes)
+
+    Args:
+        pem_data: PEM-encoded certificate chain (bytes)
+
+    Returns:
+        List of parsed certificates in order
+
+    Raises:
+        CertificateChainError: If parsing fails
+    """
+    try:
+        return x509.load_pem_x509_certificates(pem_data)
+    except Exception as e:
+        raise CertificateChainError(f"Failed to parse PEM certificate chain: {e}") from e
+
+
+def certs_to_pem(certs: List[x509.Certificate]) -> str:
+    """
+    Convert list of certificates to concatenated PEM string.
+
+    Args:
+        certs: List of certificates
+
+    Returns:
+        Concatenated PEM string
+    """
+    pem_parts = []
+    for cert in certs:
+        pem_parts.append(cert.public_bytes(serialization.Encoding.PEM).decode("ascii"))
+    return "".join(pem_parts)
+
+
+def verify_intel_chain(
+    certs: List[x509.Certificate],
+    chain_name: str = "Certificate chain",
+) -> None:
+    """
+    Verify a certificate chain against Intel SGX Root CA.
+
+    Performs manual chain verification without requiring TLS extensions (SAN).
+    Intel PCK certificates don't have a SAN extension, so we can't use
+    PolicyBuilder.build_client_verifier().
+
+    Verification steps:
+    1. Verify the root cert matches the embedded Intel SGX Root CA (by public key)
+    2. Verify each certificate's validity period
+    3. Verify issuer certificates have BasicConstraints CA=True
+    4. Verify each certificate was issued by the next cert in the chain
+    5. Verify the chain's root is signed by the trusted Intel root
+
+    Args:
+        certs: Certificate chain [leaf, intermediate(s)..., root]
+        chain_name: Human-readable name for error messages
+
+    Raises:
+        CertificateChainError: If chain verification fails
+    """
+    if len(certs) < 2:
+        raise CertificateChainError(
+            f"{chain_name} must contain at least 2 certificates (leaf and root)"
+        )
+
+    intel_root = get_intel_root_ca()
+
+    # Step 1: Verify root certificate matches Intel SGX Root CA by public key
+    chain_root = certs[-1]
+    chain_root_pubkey = chain_root.public_key().public_bytes(
+        encoding=serialization.Encoding.DER,
+        format=serialization.PublicFormat.SubjectPublicKeyInfo,
+    )
+    intel_root_pubkey = intel_root.public_key().public_bytes(
+        encoding=serialization.Encoding.DER,
+        format=serialization.PublicFormat.SubjectPublicKeyInfo,
+    )
+
+    if chain_root_pubkey != intel_root_pubkey:
+        raise CertificateChainError(
+            f"{chain_name} root certificate does not match Intel SGX Root CA"
+        )
+
+    # Step 2: Verify each certificate's validity period
+    now = datetime.now(timezone.utc)
+    for cert in certs:
+        if now < cert.not_valid_before_utc:
+            raise CertificateChainError(
+                f"{chain_name}: certificate not yet valid (not before {cert.not_valid_before_utc})"
+            )
+        if now > cert.not_valid_after_utc:
+            raise CertificateChainError(
+                f"{chain_name}: certificate expired (not after {cert.not_valid_after_utc})"
+            )
+
+    # Step 3: Verify issuer certificates are CAs (BasicConstraints)
+    # Go's x509.Verify checks this implicitly; we must do it explicitly.
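+    # Without this check, a non-CA (end-entity) certificate that had signed
+    # another certificate would be accepted as a valid issuer in the chain.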
+ for cert in certs[1:]: + try: + bc = cert.extensions.get_extension_for_class(x509.BasicConstraints) + if not bc.value.ca: + raise CertificateChainError( + f"{chain_name}: intermediate/root certificate is not a CA" + ) + except x509.ExtensionNotFound: + raise CertificateChainError( + f"{chain_name}: intermediate/root certificate missing BasicConstraints extension" + ) + + # Step 4: Verify certificate chain signatures using verify_directly_issued_by + # Each cert[i] should be signed by cert[i+1] + for i in range(len(certs) - 1): + cert = certs[i] + issuer = certs[i + 1] + try: + cert.verify_directly_issued_by(issuer) + except Exception as e: + raise CertificateChainError( + f"{chain_name}: certificate chain signature verification failed: {e}" + ) + + # Step 5: Verify the chain's root is signed by the trusted Intel root + try: + chain_root.verify_directly_issued_by(intel_root) + except Exception as e: + raise CertificateChainError( + f"{chain_name}: root certificate verification against Intel SGX Root CA failed: {e}" + ) diff --git a/src/tinfoil/attestation/intel_root_ca.py b/src/tinfoil/attestation/intel_root_ca.py new file mode 100644 index 0000000..bd31a36 --- /dev/null +++ b/src/tinfoil/attestation/intel_root_ca.py @@ -0,0 +1,80 @@ +""" +Intel SGX Root CA certificate for TDX attestation verification. + +This module provides the embedded Intel SGX Provisioning Certification Root CA +certificate, which is the trust anchor for verifying TDX attestation quotes. + +Certificate chain for TDX: + Intel SGX Root CA (this file - trust anchor) + └─► Intel SGX PCK Platform CA (intermediate, from collateral) + └─► Intel SGX PCK Certificate (leaf, embedded in quote) + +Source: https://certificates.trustedservices.intel.com/Intel_SGX_Provisioning_Certification_RootCA.pem +""" + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization + +# Intel SGX Provisioning Certification Root CA +# Subject: CN=Intel SGX Root CA, O=Intel Corporation, L=Santa Clara, ST=CA, C=US +# Valid: 2018-05-21 to 2049-12-31 +# Key: ECDSA P-256 +INTEL_SGX_ROOT_CA_PEM = b"""-----BEGIN CERTIFICATE----- +MIICjzCCAjSgAwIBAgIUImUM1lqdNInzg7SVUr9QGzknBqwwCgYIKoZIzj0EAwIw +aDEaMBgGA1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENv +cnBvcmF0aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJ +BgNVBAYTAlVTMB4XDTE4MDUyMTEwNDUxMFoXDTQ5MTIzMTIzNTk1OVowaDEaMBgG +A1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENvcnBvcmF0 +aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJBgNVBAYT +AlVTMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEC6nEwMDIYZOj/iPWsCzaEKi7 +1OiOSLRFhWGjbnBVJfVnkY4u3IjkDYYL0MxO4mqsyYjlBalTVYxFP2sJBK5zlKOB +uzCBuDAfBgNVHSMEGDAWgBQiZQzWWp00ifODtJVSv1AbOScGrDBSBgNVHR8ESzBJ +MEegRaBDhkFodHRwczovL2NlcnRpZmljYXRlcy50cnVzdGVkc2VydmljZXMuaW50 +ZWwuY29tL0ludGVsU0dYUm9vdENBLmRlcjAdBgNVHQ4EFgQUImUM1lqdNInzg7SV +Ur9QGzknBqwwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQEwCgYI +KoZIzj0EAwIDSQAwRgIhAOW/5QkR+S9CiSDcNoowLuPRLsWGf/Yi7GSX94BgwTwg +AiEA4J0lrHoMs+Xo5o/sX6O9QWxHRAvZUGOdRQ7cvqRXaqI= +-----END CERTIFICATE----- +""" + + +def get_intel_root_ca() -> x509.Certificate: + """ + Load and return the Intel SGX Root CA certificate. + + Returns: + Parsed X.509 certificate object + + Example: + >>> root_ca = get_intel_root_ca() + >>> print(root_ca.subject) + + """ + return x509.load_pem_x509_certificate(INTEL_SGX_ROOT_CA_PEM) + + +def get_intel_root_ca_public_key_der() -> bytes: + """ + Get the Intel SGX Root CA public key in DER format. 
+ + This is useful for comparing against certificate chain anchors. + + Returns: + DER-encoded public key bytes + """ + cert = get_intel_root_ca() + return cert.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + +def get_intel_root_ca_der() -> bytes: + """ + Get the Intel SGX Root CA certificate in DER format. + + Returns: + DER-encoded certificate bytes + """ + cert = get_intel_root_ca() + return cert.public_bytes(serialization.Encoding.DER) From 4105a61fbc52d592b5d81f199e9ab5184d4c4cb1 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:14:59 -0500 Subject: [PATCH 07/19] feat: add TDX attestation orchestration (attestation_tdx.py) Implement end-to-end TDX attestation flow: quote parsing, cryptographic verification, policy validation, and collateral validation. Adds TdxVerificationConfig for policy injection, verify_tdx_hardware for Sigstore-based measurement comparison, and ACCEPTED_MR_SEAMS whitelist. --- src/tinfoil/attestation/attestation_tdx.py | 329 ++++++++ tests/test_tdx_attestation_flow.py | 93 +++ tests/test_tdx_multiplatform.py | 710 ++++++++++++++++++ tests/test_verification_failures.py | 298 ++++++++ .../test_verification_failures_integration.py | 143 ++++ 5 files changed, 1573 insertions(+) create mode 100644 src/tinfoil/attestation/attestation_tdx.py create mode 100644 tests/test_tdx_attestation_flow.py create mode 100644 tests/test_tdx_multiplatform.py create mode 100644 tests/test_verification_failures.py create mode 100644 tests/test_verification_failures_integration.py diff --git a/src/tinfoil/attestation/attestation_tdx.py b/src/tinfoil/attestation/attestation_tdx.py new file mode 100644 index 0000000..33817a7 --- /dev/null +++ b/src/tinfoil/attestation/attestation_tdx.py @@ -0,0 +1,329 @@ +""" +TDX Attestation Orchestration Module. + +This module provides the high-level entry point for TDX attestation verification. +It coordinates between: +- Quote parsing (abi_tdx) +- Cryptographic verification (verify_tdx) +- Policy validation (validate_tdx) +- Collateral validation (collateral_tdx) + +Usage: + from tinfoil.attestation.attestation_tdx import verify_tdx_attestation + + result = verify_tdx_attestation(attestation_doc) + # result.measurements contains the 5 TDX measurements + # result.tls_key_fp contains the TLS key fingerprint +""" + +import base64 +from dataclasses import dataclass, field +from typing import Optional + +from .abi_tdx import ( + parse_quote, + QuoteV4, + TdxQuoteParseError, + INTEL_QE_VENDOR_ID, + MR_CONFIG_ID_SIZE, + MR_OWNER_SIZE, + MR_OWNER_CONFIG_SIZE, +) +from .types import TLS_KEY_FP_SIZE, HPKE_KEY_SIZE, Measurement, Verification, PredicateType, HardwareMeasurement, HardwareMeasurementError, TDX_MRTD_IDX, TDX_RTMR0_IDX, TDX_REGISTER_COUNT, safe_gzip_decompress +from .verify_tdx import ( + verify_tdx_quote, + TdxVerificationError, + PCKCertificateChain, +) +from .validate_tdx import ( + validate_tdx_policy, + PolicyOptions, + TdQuoteBodyOptions, + HeaderOptions, + TdxValidationError, +) +from .pck_extensions import PckExtensions +from .collateral_tdx import ( + validate_collateral, + CollateralError, + TdxCollateral, + TcbLevel, +) + + +# ============================================================================= +# Orchestration Constants +# ============================================================================= + +# Minimum required tcbEvaluationDataNumber. +# This prevents accepting collateral issued before critical security updates. 
+# See: https://www.intel.com/content/www/us/en/developer/topic-technology/software-security-guidance/trusted-computing-base-recovery-attestation.html +# +# This value should be set using: +# from tinfoil.attestation.collateral_tdx import calculate_min_tcb_evaluation_data_number +# min_num = calculate_min_tcb_evaluation_data_number() +# +# The function queries Intel PCS and returns the lowest tcbEvaluationDataNumber +# whose TCB recovery event date is within the last year. +# +# Current value 18 corresponds to TCB recovery event date 2024-11-12. +DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER = 18 + +# Expected values for policy validation +# TdAttributes: All zeros except SEPT_VE_DISABLE=1 +EXPECTED_TD_ATTRIBUTES = bytes.fromhex("0000001000000000") +# XFam: Enable FP, SSE, AVX, AVX512, PK, AMX +EXPECTED_XFAM = bytes.fromhex("e702060000000000") +# MinimumTeeTcbSvn: 3.1.2 +EXPECTED_MINIMUM_TEE_TCB_SVN = bytes.fromhex("03010200000000000000000000000000") + +# Accepted MR_SEAM values from Intel TDX module releases +# https://github.com/intel/confidential-computing.tdx.tdx-module/releases +ACCEPTED_MR_SEAMS: tuple[bytes, ...] = ( + bytes.fromhex("476a2997c62bccc78370913d0a80b956e3721b24272bc66c4d6307ced4be2865c40e26afac75f12df3425b03eb59ea7c"), # TDX Module 2.0.08 + bytes.fromhex("7bf063280e94fb051f5dd7b1fc59ce9aac42bb961df8d44b709c9b0ff87a7b4df648657ba6d1189589feab1d5a3c9a9d"), # TDX Module 1.5.16 + bytes.fromhex("685f891ea5c20e8fa27b151bf34bf3b50fbaf7143cc53662727cbdb167c0ad8385f1f6f3571539a91e104a1c96d75e04"), # TDX Module 2.0.02 + bytes.fromhex("49b66faa451d19ebbdbe89371b8daf2b65aa3984ec90110343e9e2eec116af08850fa20e3b1aa9a874d77a65380ee7e6"), # TDX Module 1.5.08 +) + + +# ============================================================================= +# Orchestration Config +# ============================================================================= + +@dataclass(frozen=True) +class TdxVerificationConfig: + """ + Configuration for TDX attestation verification. + + All fields have sensible defaults matching the current hardcoded values. + Override individual fields to customise verification policy. + """ + min_tcb_evaluation_data_number: int = DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER + accepted_mr_seams: tuple[bytes, ...] = ACCEPTED_MR_SEAMS + policy_options: Optional[PolicyOptions] = None + expected_td_attributes: bytes = EXPECTED_TD_ATTRIBUTES + expected_xfam: bytes = EXPECTED_XFAM + expected_minimum_tee_tcb_svn: bytes = EXPECTED_MINIMUM_TEE_TCB_SVN + + +_DEFAULT_CONFIG = TdxVerificationConfig() + + +# ============================================================================= +# Orchestration Error Type +# ============================================================================= + +class TdxAttestationError(Exception): + """ + Raised when TDX attestation verification fails. + """ + pass + + +# ============================================================================= +# Orchestration Result Type +# ============================================================================= + +@dataclass +class TdxAttestationResult: + """ + Result of TDX attestation verification. + + Contains the verified quote, measurements, and TCB status. 
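+
+    Attributes:
+        measurements: [MRTD, RTMR0, RTMR1, RTMR2, RTMR3], hex-encoded
+        tls_key_fp: first TLS_KEY_FP_SIZE bytes of report_data, hex-encoded
+        hpke_public_key: next HPKE_KEY_SIZE bytes of report_data, hex-encoded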
+ """ + quote: QuoteV4 + pck_chain: PCKCertificateChain + pck_extensions: PckExtensions + collateral: TdxCollateral + tcb_level: TcbLevel + measurements: list[str] # [MRTD, RTMR0, RTMR1, RTMR2, RTMR3] + tls_key_fp: str + hpke_public_key: str + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + +def verify_tdx_attestation( + attestation_doc: str, + is_compressed: bool = True, + config: Optional[TdxVerificationConfig] = None, +) -> TdxAttestationResult: + """ + Verify a TDX attestation document. + + This is the main entry point for TDX attestation verification. + It performs the complete verification flow: + + 1. Decode and decompress the attestation document + 2. Parse the TDX quote + 3. Verify cryptographic signatures (PCK chain, quote, QE report) + 4. Policy validation (XFAM, TD_ATTRIBUTES, SEAM, MR_SEAM whitelist) + 5. Fetch and validate collateral (PCK extensions, TCB status, QE identity, revocation) + 6. Extract measurements and report data + + Args: + attestation_doc: Base64-encoded attestation document + is_compressed: Whether the document is gzip compressed + config: Optional verification config. Uses sensible defaults if None. + + Returns: + TdxAttestationResult containing verified data + + Raises: + TdxAttestationError: If any verification step fails + """ + if config is None: + config = _DEFAULT_CONFIG + # Step 1: Decode the attestation document + try: + raw_bytes = base64.b64decode(attestation_doc) + except Exception as e: + raise TdxAttestationError(f"Failed to decode base64: {e}") from e + + if is_compressed: + try: + raw_bytes = safe_gzip_decompress(raw_bytes) + except Exception as e: + raise TdxAttestationError(f"Failed to decompress: {e}") from e + + # Step 2: Parse the TDX quote + try: + quote = parse_quote(raw_bytes) + except TdxQuoteParseError as e: + raise TdxAttestationError(f"Failed to parse TDX quote: {e}") from e + + # Step 3: Verify cryptographic signatures + try: + pck_chain = verify_tdx_quote(quote, raw_bytes) + except TdxVerificationError as e: + raise TdxAttestationError(f"TDX quote verification failed: {e}") from e + + # Step 4: Policy validation (mirrors Go's validate.TdxQuote) + if config.policy_options is not None: + policy_options = config.policy_options + else: + policy_options = PolicyOptions( + header=HeaderOptions( + qe_vendor_id=INTEL_QE_VENDOR_ID, + ), + td_quote_body=TdQuoteBodyOptions( + minimum_tee_tcb_svn=config.expected_minimum_tee_tcb_svn, + td_attributes=config.expected_td_attributes, + xfam=config.expected_xfam, + mr_config_id=b'\x00' * MR_CONFIG_ID_SIZE, + mr_owner=b'\x00' * MR_OWNER_SIZE, + mr_owner_config=b'\x00' * MR_OWNER_CONFIG_SIZE, + any_mr_seam=config.accepted_mr_seams, + ), + ) + + try: + validate_tdx_policy(quote, policy_options) + except TdxValidationError as e: + raise TdxAttestationError(f"Policy validation failed: {e}") from e + + # Step 5: Validate collateral (PCK extensions, TCB status, QE identity, revocation) + try: + collateral_result = validate_collateral( + quote=quote, + pck_chain=pck_chain, + min_tcb_evaluation_data_number=config.min_tcb_evaluation_data_number, + ) + except CollateralError as e: + raise TdxAttestationError(f"Collateral validation failed: {e}") from e + + # Step 6: Extract measurements and report data + measurements = quote.td_quote_body.get_measurements() + measurements_hex = [m.hex() for m in measurements] + + report_data = quote.td_quote_body.report_data + 
required_len = TLS_KEY_FP_SIZE + HPKE_KEY_SIZE + if len(report_data) < required_len: + raise TdxAttestationError( + f"report_data too short: {len(report_data)} bytes, need at least {required_len}" + ) + tls_key_fp = report_data[0:TLS_KEY_FP_SIZE].hex() + hpke_public_key = report_data[TLS_KEY_FP_SIZE:TLS_KEY_FP_SIZE + HPKE_KEY_SIZE].hex() + + return TdxAttestationResult( + quote=quote, + pck_chain=pck_chain, + pck_extensions=collateral_result.pck_extensions, + collateral=collateral_result.collateral, + tcb_level=collateral_result.tcb_level, + measurements=measurements_hex, + tls_key_fp=tls_key_fp, + hpke_public_key=hpke_public_key, + ) + + +def verify_tdx_attestation_v2(attestation_doc: str) -> Verification: + """ + Verify TDX attestation document (v2 format) and return verification result. + + v2 format: report_data contains TLS key fingerprint (32 bytes) + HPKE public key (32 bytes). + + Args: + attestation_doc: Base64-encoded, gzip-compressed TDX quote + + Returns: + Verification containing measurements, public key fingerprint, and HPKE public key + + Raises: + ValueError: If verification fails + """ + try: + result = verify_tdx_attestation(attestation_doc, is_compressed=True) + except TdxAttestationError as e: + raise ValueError(f"TDX attestation verification failed: {e}") from e + + measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=result.measurements, + ) + + return Verification( + measurement=measurement, + public_key_fp=result.tls_key_fp, + hpke_public_key=result.hpke_public_key, + ) + + +def verify_tdx_hardware( + hardware_measurements: list[HardwareMeasurement], + enclave_measurement: Measurement, +) -> HardwareMeasurement: + """ + Verify that the enclave's MRTD and RTMR0 match a known hardware platform. + + Args: + hardware_measurements: List of known-good hardware measurements from Sigstore + enclave_measurement: The measurement from the TDX enclave attestation + + Returns: + The matching HardwareMeasurement + + Raises: + HardwareMeasurementError: If no matching hardware platform is found + ValueError: If enclave measurement is invalid + """ + if enclave_measurement is None: + raise ValueError("enclave measurement is None") + + if enclave_measurement.type != PredicateType.TDX_GUEST_V2: + raise ValueError(f"unsupported enclave platform: {enclave_measurement.type}") + + if len(enclave_measurement.registers) != TDX_REGISTER_COUNT: + raise ValueError(f"expected {TDX_REGISTER_COUNT} TDX registers, got {len(enclave_measurement.registers)}") + + enclave_mrtd = enclave_measurement.registers[TDX_MRTD_IDX] + enclave_rtmr0 = enclave_measurement.registers[TDX_RTMR0_IDX] + + for hw in hardware_measurements: + if hw.mrtd == enclave_mrtd and hw.rtmr0 == enclave_rtmr0: + return hw + + raise HardwareMeasurementError("no matching hardware platform found") diff --git a/tests/test_tdx_attestation_flow.py b/tests/test_tdx_attestation_flow.py new file mode 100644 index 0000000..4df9ff4 --- /dev/null +++ b/tests/test_tdx_attestation_flow.py @@ -0,0 +1,93 @@ +""" +Integration test for TDX attestation verification flow. + +Tests the complete verification using the SecureClient API. 
+ +Configure via environment variables: + TINFOIL_TEST_REPO - GitHub repo (e.g., tinfoilsh/confidential-gpt-oss-120b-free) + TINFOIL_TEST_ENCLAVE - Enclave hostname (e.g., gpt-oss-120b-free.inf5.tinfoil.sh) + +Example: + TINFOIL_TEST_REPO=tinfoilsh/confidential-deepseek-r1-0528 \ + TINFOIL_TEST_ENCLAVE=deepseek-r1-0528.inf9.tinfoil.sh \ + python -m pytest tests/test_tdx_attestation_flow.py -v -s +""" + +import os +import pytest + +from tinfoil.client import SecureClient +from tinfoil.attestation import PredicateType, TDX_TYPES + +pytestmark = pytest.mark.integration # allows pytest -m integration filtering + +# Test configuration from environment or defaults +REPO = os.environ.get("TINFOIL_TEST_REPO", "tinfoilsh/confidential-gpt-oss-120b-free") +ENCLAVE = os.environ.get("TINFOIL_TEST_ENCLAVE", "gpt-oss-120b-free.inf5.tinfoil.sh") + + +def test_tdx_full_verification_flow(): + """ + Tests the complete TDX attestation verification flow using SecureClient. + + SecureClient.verify() performs: + 1. Fetch runtime attestation from enclave + 2. Verify attestation (cryptographic + policy validation) + 3. Fetch latest digest from GitHub + 4. Fetch and verify sigstore attestation bundle + 5. For TDX: verify hardware measurements (MRTD, RTMR0) + 6. Compare code measurements with runtime measurements + """ + print(f"\nVerifying TDX enclave: {ENCLAVE}") + print(f"Against repo: {REPO}") + + client = SecureClient(enclave=ENCLAVE, repo=REPO) + ground_truth = client.verify() + + # Check this is actually TDX + if ground_truth.measurement.type not in TDX_TYPES: + pytest.skip( + f"Enclave returned {ground_truth.measurement.type}, not TDX. " + "This test is specifically for TDX enclaves." + ) + + print(f"\n✓ TDX Verification successful!") + print(f" Measurement type: {ground_truth.measurement.type}") + print(f" Measurement fingerprint: {ground_truth.measurement.fingerprint()}") + print(f" Public key fingerprint: {ground_truth.public_key}") + print(f" Digest: {ground_truth.digest}") + + # Print TDX-specific measurements + regs = ground_truth.measurement.registers + print("\n TDX Measurements:") + print(f" MRTD: {regs[0][:32]}...") + print(f" RTMR0: {regs[1][:32]}...") + print(f" RTMR1: {regs[2][:32]}...") + print(f" RTMR2: {regs[3][:32]}...") + print(f" RTMR3: {regs[4][:32]}...") + + +def test_tdx_secure_http_client(): + """ + Tests that SecureClient creates a working pinned HTTP client for TDX enclaves. + """ + print(f"\nCreating secure HTTP client for: {ENCLAVE}") + + client = SecureClient(enclave=ENCLAVE, repo=REPO) + http_client = client.make_secure_http_client() + + ground_truth = client.ground_truth + assert ground_truth is not None + + if ground_truth.measurement.type not in TDX_TYPES: + http_client.close() + pytest.skip(f"Enclave returned {ground_truth.measurement.type}, not TDX.") + + print(f"\n✓ Secure HTTP client created successfully!") + print(f" TLS pinned to: {ground_truth.public_key}") + + http_client.close() + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_tdx_multiplatform.py b/tests/test_tdx_multiplatform.py new file mode 100644 index 0000000..a87dd46 --- /dev/null +++ b/tests/test_tdx_multiplatform.py @@ -0,0 +1,710 @@ +""" +Unit tests for TDX multiplatform measurement handling. + +Covers: +1. Sigstore multiplatform TDX measurement parsing (SNP_TDX_MULTIPLATFORM_v1) +2. RTMR3-zero enforcement in measurement comparison +3. 
Module-identity matching with real tee_tcb_svn values +""" + +import json +import pytest +from unittest.mock import patch, MagicMock +from datetime import datetime, timezone, timedelta + +from tinfoil.attestation import ( + Measurement, + PredicateType, + MeasurementMismatchError, + Rtmr3NotZeroError, + RTMR3_ZERO, +) +from tinfoil.attestation.collateral_tdx import ( + TcbInfo, + TcbLevel, + Tcb, + TcbComponent, + TcbStatus, + TdxModuleIdentity, + get_tdx_module_identity, + validate_tdx_module_identity, + CollateralError, +) + + +# ============================================================================= +# Test Data - Realistic TDX Measurements +# ============================================================================= + +# 48-byte hex strings (96 chars) for TDX measurements +SAMPLE_MRTD = "a1" * 48 +SAMPLE_RTMR0 = "b2" * 48 +SAMPLE_RTMR1 = "c3" * 48 +SAMPLE_RTMR2 = "d4" * 48 +SAMPLE_RTMR3_ZEROS = "00" * 48 +SAMPLE_RTMR3_NONZERO = "e5" * 48 +SAMPLE_SNP_MEASUREMENT = "f6" * 48 + + +# ============================================================================= +# Sigstore Multiplatform TDX Measurement Parsing Tests +# ============================================================================= + +class TestSigstoreMultiplatformParsing: + """Test parsing of SNP_TDX_MULTIPLATFORM_v1 predicates from Sigstore.""" + + def _create_mock_bundle_payload( + self, + snp_measurement: str, + rtmr1: str, + rtmr2: str, + digest: str = "abc123", + ) -> bytes: + """Create a mock in-toto payload with multiplatform predicate.""" + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + "snp_measurement": snp_measurement, + "tdx_measurement": { + "rtmr1": rtmr1, + "rtmr2": rtmr2, + }, + }, + "subject": [{"digest": {"sha256": digest}}], + } + return json.dumps(payload).encode() + + def test_parse_multiplatform_extracts_all_registers(self): + """Test that multiplatform parsing extracts snp_measurement, rtmr1, rtmr2.""" + from tinfoil.sigstore import verify_attestation + + payload = self._create_mock_bundle_payload( + snp_measurement=SAMPLE_SNP_MEASUREMENT, + rtmr1=SAMPLE_RTMR1, + rtmr2=SAMPLE_RTMR2, + digest="test_digest_abc123", + ) + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload, + ) + + with patch('tinfoil.sigstore.Bundle'): + result = verify_attestation( + bundle_json=b'{}', + digest="test_digest_abc123", + repo="test/repo", + ) + + assert result.type == PredicateType.SNP_TDX_MULTIPLATFORM_v1 + assert len(result.registers) == 3 + assert result.registers[0] == SAMPLE_SNP_MEASUREMENT + assert result.registers[1] == SAMPLE_RTMR1 + assert result.registers[2] == SAMPLE_RTMR2 + + def test_parse_multiplatform_missing_snp_measurement_fails(self): + """Test that missing snp_measurement raises ValueError.""" + from tinfoil.sigstore import verify_attestation + + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + # snp_measurement missing + "tdx_measurement": { + "rtmr1": SAMPLE_RTMR1, + "rtmr2": SAMPLE_RTMR2, + }, + }, + "subject": [{"digest": {"sha256": "test_digest"}}], + } + payload_bytes = json.dumps(payload).encode() + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 
'application/vnd.in-toto+json', + payload_bytes, + ) + + with patch('tinfoil.sigstore.Bundle'): + with pytest.raises(ValueError, match="no snp_measurement"): + verify_attestation(b'{}', "test_digest", "test/repo") + + def test_parse_multiplatform_missing_tdx_measurement_fails(self): + """Test that missing tdx_measurement struct raises ValueError.""" + from tinfoil.sigstore import verify_attestation + + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + "snp_measurement": SAMPLE_SNP_MEASUREMENT, + # tdx_measurement missing + }, + "subject": [{"digest": {"sha256": "test_digest"}}], + } + payload_bytes = json.dumps(payload).encode() + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload_bytes, + ) + + with patch('tinfoil.sigstore.Bundle'): + with pytest.raises(ValueError, match="no tdx_measurement"): + verify_attestation(b'{}', "test_digest", "test/repo") + + def test_parse_multiplatform_missing_rtmr1_fails(self): + """Test that missing rtmr1 in tdx_measurement raises ValueError.""" + from tinfoil.sigstore import verify_attestation + + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + "snp_measurement": SAMPLE_SNP_MEASUREMENT, + "tdx_measurement": { + # rtmr1 missing + "rtmr2": SAMPLE_RTMR2, + }, + }, + "subject": [{"digest": {"sha256": "test_digest"}}], + } + payload_bytes = json.dumps(payload).encode() + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload_bytes, + ) + + with patch('tinfoil.sigstore.Bundle'): + with pytest.raises(ValueError, match="missing rtmr1 or rtmr2"): + verify_attestation(b'{}', "test_digest", "test/repo") + + def test_parse_multiplatform_missing_rtmr2_fails(self): + """Test that missing rtmr2 in tdx_measurement raises ValueError.""" + from tinfoil.sigstore import verify_attestation + + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + "snp_measurement": SAMPLE_SNP_MEASUREMENT, + "tdx_measurement": { + "rtmr1": SAMPLE_RTMR1, + # rtmr2 missing + }, + }, + "subject": [{"digest": {"sha256": "test_digest"}}], + } + payload_bytes = json.dumps(payload).encode() + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload_bytes, + ) + + with patch('tinfoil.sigstore.Bundle'): + with pytest.raises(ValueError, match="missing rtmr1 or rtmr2"): + verify_attestation(b'{}', "test_digest", "test/repo") + + +# ============================================================================= +# RTMR3-Zero Enforcement Tests +# ============================================================================= + +class TestRtmr3ZeroEnforcement: + """Test RTMR3-zero enforcement in measurement comparison.""" + + def test_rtmr3_zero_constant_is_correct(self): + """Test that RTMR3_ZERO constant is 96 hex zeros (48 bytes).""" + assert len(RTMR3_ZERO) == 96 + assert RTMR3_ZERO == "0" * 96 + + def test_multiplatform_vs_tdx_with_zero_rtmr3_passes(self): + """Test comparison passes when TDX RTMR3 is zeros.""" + 
multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, # MRTD (index 0) + SAMPLE_RTMR0, # RTMR0 (index 1) + SAMPLE_RTMR1, # RTMR1 (index 2) - must match + SAMPLE_RTMR2, # RTMR2 (index 3) - must match + SAMPLE_RTMR3_ZEROS, # RTMR3 (index 4) - must be zeros + ], + ) + + # Should not raise + multiplatform.assert_equal(tdx) + + def test_multiplatform_vs_tdx_with_nonzero_rtmr3_fails(self): + """Test comparison fails when TDX RTMR3 is not zeros.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + SAMPLE_RTMR1, + SAMPLE_RTMR2, + SAMPLE_RTMR3_NONZERO, # RTMR3 is NOT zeros + ], + ) + + with pytest.raises(Rtmr3NotZeroError, match="RTMR3 must be zeros"): + multiplatform.assert_equal(tdx) + + def test_multiplatform_vs_tdx_rtmr1_mismatch_fails(self): + """Test comparison fails when RTMR1 doesn't match.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + "wrong_rtmr1" + "0" * 80, # Wrong RTMR1 + SAMPLE_RTMR2, + SAMPLE_RTMR3_ZEROS, + ], + ) + + with pytest.raises(MeasurementMismatchError): + multiplatform.assert_equal(tdx) + + def test_multiplatform_vs_tdx_rtmr2_mismatch_fails(self): + """Test comparison fails when RTMR2 doesn't match.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + SAMPLE_RTMR1, + "wrong_rtmr2" + "0" * 80, # Wrong RTMR2 + SAMPLE_RTMR3_ZEROS, + ], + ) + + with pytest.raises(MeasurementMismatchError): + multiplatform.assert_equal(tdx) + + def test_reverse_comparison_tdx_vs_multiplatform(self): + """Test reverse comparison (TDX.assert_equal(multiplatform)) also works.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + SAMPLE_RTMR1, + SAMPLE_RTMR2, + SAMPLE_RTMR3_ZEROS, + ], + ) + + # Reverse comparison should also work + tdx.assert_equal(multiplatform) + + def test_reverse_comparison_with_nonzero_rtmr3_fails(self): + """Test reverse comparison also enforces RTMR3-zero.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + SAMPLE_RTMR1, + SAMPLE_RTMR2, + SAMPLE_RTMR3_NONZERO, + ], + ) + + with pytest.raises(Rtmr3NotZeroError): + tdx.assert_equal(multiplatform) + + def test_multiplatform_vs_snp_compares_snp_measurement(self): + """Test multiplatform vs SNP compares the SNP measurement register.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + snp = Measurement( + type=PredicateType.SEV_GUEST_V2, + 
registers=[SAMPLE_SNP_MEASUREMENT], + ) + + # Should not raise - SNP measurement matches + multiplatform.assert_equal(snp) + + def test_multiplatform_vs_snp_mismatch_fails(self): + """Test multiplatform vs SNP fails when SNP measurement doesn't match.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + snp = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["wrong_snp_" + "0" * 86], # Wrong SNP measurement + ) + + with pytest.raises(MeasurementMismatchError): + multiplatform.assert_equal(snp) + + +# ============================================================================= +# Module Identity Matching with Real TEE_TCB_SVN Values +# ============================================================================= + +class TestModuleIdentityMatchingWithRealTeeTcbSvn: + """Test module-identity matching with realistic tee_tcb_svn byte patterns.""" + + def _create_tcb_info_with_modules( + self, + module_ids: list[str], + module_mrsigner: bytes = b"\xaa" * 48, + ) -> TcbInfo: + """Create TCB Info with specified module identities.""" + module_identities = [] + for module_id in module_ids: + module_identities.append( + TdxModuleIdentity( + id=module_id, + mrsigner=module_mrsigner, + attributes=b"\x00" * 8, + attributes_mask=b"\xff" * 8, + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=3, # Minimum minor version + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ) + ) + + return TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=module_identities, + tcb_levels=[], + ) + + # --- TEE_TCB_SVN byte layout tests --- + # TEE_TCB_SVN is 16 bytes where: + # - Byte 0 (index 0) = minor SVN (used for module TCB level matching) + # - Byte 1 (index 1) = major SVN (used to derive module ID: TDX_{major:02d}) + # - Bytes 2-15 = other TCB components + + def test_tee_tcb_svn_major_version_3_matches_tdx_03(self): + """Test TEE_TCB_SVN with major=3 matches TDX_03 module.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_01", "TDX_03", "TDX_05"]) + + # TEE_TCB_SVN: minor=5, major=3 (TDX module version 3) + tee_tcb_svn = bytes([5, 3] + [0] * 14) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_03" + + def test_tee_tcb_svn_major_version_1_matches_tdx_01(self): + """Test TEE_TCB_SVN with major=1 matches TDX_01 module.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_01", "TDX_03"]) + + # TEE_TCB_SVN: minor=2, major=1 (TDX module version 1) + tee_tcb_svn = bytes([2, 1] + [0] * 14) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_01" + + def test_tee_tcb_svn_major_version_5_matches_tdx_05(self): + """Test TEE_TCB_SVN with major=5 matches TDX_05 module.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_03", "TDX_05"]) + + # TEE_TCB_SVN: minor=0, major=5 (TDX module version 5) + tee_tcb_svn = bytes([0, 5] + [0] * 14) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_05" + + def 
test_tee_tcb_svn_unknown_major_version_raises(self): + """Test TEE_TCB_SVN with unknown major version raises error.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_01", "TDX_03"]) + + # TEE_TCB_SVN: minor=0, major=99 (no TDX_99 module exists) + tee_tcb_svn = bytes([0, 99] + [0] * 14) + + with pytest.raises(CollateralError, match="Unknown TDX module version TDX_99"): + get_tdx_module_identity(tcb_info, tee_tcb_svn) + + def test_tee_tcb_svn_major_version_0_matches_tdx_00(self): + """Test TEE_TCB_SVN with major=0 matches TDX_00 if present.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_00", "TDX_01"]) + + # TEE_TCB_SVN: minor=1, major=0 (TDX module version 0) + tee_tcb_svn = bytes([1, 0] + [0] * 14) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_00" + + def test_tee_tcb_svn_too_short_raises(self): + """Test TEE_TCB_SVN with < 2 bytes raises error.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_03"]) + + # Only 1 byte - not enough to extract major version + tee_tcb_svn = bytes([5]) + + with pytest.raises(CollateralError, match="TEE_TCB_SVN is too short"): + get_tdx_module_identity(tcb_info, tee_tcb_svn) + + def test_validate_module_identity_with_real_svn_values(self): + """Test full module identity validation with realistic SVN values.""" + module_mrsigner = b"\xaa" * 48 + tcb_info = self._create_tcb_info_with_modules( + ["TDX_03"], + module_mrsigner=module_mrsigner, + ) + + # Realistic TEE_TCB_SVN: minor=5, major=3 + # This should match TDX_03 and pass TCB level check (minor >= 3) + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = module_mrsigner + seam_attributes = b"\x00" * 8 + + # Should not raise - validation succeeds + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_minor_svn_too_low_raises_error(self): + """Test that minor SVN below threshold raises CollateralError.""" + module_mrsigner = b"\xaa" * 48 + + tcb_info = TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=[ + TdxModuleIdentity( + id="TDX_03", + mrsigner=module_mrsigner, + attributes=b"\x00" * 8, + attributes_mask=b"\xff" * 8, + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=5, # Requires minor >= 5 + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ) + ], + tcb_levels=[], + ) + + # TEE_TCB_SVN: minor=2, major=3 - minor too low (< 5) + tee_tcb_svn = bytes([2, 3] + [0] * 14) + mr_signer_seam = module_mrsigner + seam_attributes = b"\x00" * 8 + + # No matching level found raises CollateralError + with pytest.raises(CollateralError, match="Could not find a TDX Module Identity TCB Level"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_mrsigner_mismatch_fails(self): + """Test that MR_SIGNER_SEAM mismatch raises error.""" + expected_mrsigner = b"\xaa" * 48 + tcb_info = self._create_tcb_info_with_modules( + ["TDX_03"], + module_mrsigner=expected_mrsigner, + ) + + # TEE_TCB_SVN: minor=5, major=3 + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = b"\xbb" * 48 # Wrong! 
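+        # The TCB Info pins the expected MR_SIGNER_SEAM for each module
+        # identity; a report signed by any other SEAM module must be rejected.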
+ seam_attributes = b"\x00" * 8 + + with pytest.raises(CollateralError, match="MR_SIGNER_SEAM does not match"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_attributes_mismatch_fails(self): + """Test that SEAM_ATTRIBUTES mismatch under mask raises error.""" + module_mrsigner = b"\xaa" * 48 + + tcb_info = TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=[ + TdxModuleIdentity( + id="TDX_03", + mrsigner=module_mrsigner, + attributes=b"\x01" * 8, # Expected: all 0x01 + attributes_mask=b"\xff" * 8, # Full mask + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=3, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ) + ], + tcb_levels=[], + ) + + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = module_mrsigner + seam_attributes = b"\x00" * 8 # Wrong! Expected 0x01 + + with pytest.raises(CollateralError, match="SEAM_ATTRIBUTES do not match"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_revoked_status_fails(self): + """Test that REVOKED module TCB status raises error.""" + module_mrsigner = b"\xaa" * 48 + + tcb_info = TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=[ + TdxModuleIdentity( + id="TDX_03", + mrsigner=module_mrsigner, + attributes=b"\x00" * 8, + attributes_mask=b"\xff" * 8, + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=3, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.REVOKED, # REVOKED! + advisory_ids=[], + ), + ], + ) + ], + tcb_levels=[], + ) + + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = module_mrsigner + seam_attributes = b"\x00" * 8 + + with pytest.raises(CollateralError, match="REVOKED"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_real_world_tee_tcb_svn_pattern(self): + """Test with a realistic TEE_TCB_SVN pattern from production.""" + # Real-world example: TDX module 3.x with minor version 5 + # and non-zero values in other positions + tcb_info = self._create_tcb_info_with_modules(["TDX_03"]) + + # Realistic 16-byte TEE_TCB_SVN: + # [0]=5 (minor), [1]=3 (major), [2]=0, [3]=0, ... + # Some positions might have non-zero values for platform TCB + tee_tcb_svn = bytes([5, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_03" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_verification_failures.py b/tests/test_verification_failures.py new file mode 100644 index 0000000..0cc4f18 --- /dev/null +++ b/tests/test_verification_failures.py @@ -0,0 +1,298 @@ +""" +Tests to ensure verification failures are properly propagated. 
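+
+All SecureClient entry points that can reach the network are covered below:
+verify(), make_secure_http_client(), make_secure_async_http_client(),
+get_http_client(), and make_request().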
+ +These tests ensure that if any verification step fails: +1. The error is raised, not silently ignored +2. No HTTP client is created +3. No connection to the enclave is made + +This guards against bugs where failed verification would still allow +connections to proceed. +""" + +import pytest +from unittest.mock import patch, MagicMock + +from tinfoil.client import SecureClient +from tinfoil.attestation import ( + Measurement, + PredicateType, + MeasurementMismatchError, + HardwareMeasurementError, + verify_tdx_hardware, + HardwareMeasurement, +) + + +class TestMeasurementMismatch: + """Tests that measurement mismatches raise errors and block connections.""" + + def test_measurement_equals_raises_on_mismatch(self): + """Measurement.assert_equal() must raise MeasurementMismatchError on mismatch.""" + m1 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123"] + ) + m2 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["different"] + ) + + with pytest.raises(MeasurementMismatchError): + m1.assert_equal(m2) + + def test_measurement_equals_raises_on_register_count_mismatch(self): + """Measurement.assert_equal() must raise if register counts differ.""" + m1 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123"] + ) + m2 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123", "extra"] + ) + + with pytest.raises(MeasurementMismatchError): + m1.assert_equal(m2) + + def test_measurement_equals_passes_on_match(self): + """Measurement.assert_equal() must not raise if measurements match.""" + m1 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123"] + ) + m2 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123"] + ) + + # Should not raise + m1.assert_equal(m2) + + +class TestHardwareMeasurementVerification: + """Tests that hardware measurement failures raise errors.""" + + def test_verify_hardware_raises_on_no_match(self): + """verify_tdx_hardware() must raise HardwareMeasurementError if no match.""" + hardware_measurements = [ + HardwareMeasurement( + id="platform1@digest1", + mrtd="known_mrtd_1", + rtmr0="known_rtmr0_1" + ), + HardwareMeasurement( + id="platform2@digest2", + mrtd="known_mrtd_2", + rtmr0="known_rtmr0_2" + ), + ] + + enclave_measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["unknown_mrtd", "unknown_rtmr0", "rtmr1", "rtmr2", "rtmr3"] + ) + + with pytest.raises(HardwareMeasurementError, match="no matching hardware platform"): + verify_tdx_hardware(hardware_measurements, enclave_measurement) + + def test_verify_hardware_passes_on_match(self): + """verify_tdx_hardware() must return the matching measurement.""" + hardware_measurements = [ + HardwareMeasurement( + id="platform1@digest1", + mrtd="known_mrtd", + rtmr0="known_rtmr0" + ), + ] + + enclave_measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["known_mrtd", "known_rtmr0", "rtmr1", "rtmr2", "rtmr3"] + ) + + result = verify_tdx_hardware(hardware_measurements, enclave_measurement) + assert result.id == "platform1@digest1" + + +class TestSecureClientVerificationFailures: + """Tests that SecureClient properly blocks on verification failures.""" + + @patch('tinfoil.client.fetch_attestation') + def test_attestation_failure_blocks_verify(self, mock_fetch): + """If attestation fetch fails, verify() must raise.""" + mock_fetch.side_effect = Exception("Attestation fetch failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(Exception, 
match="Attestation fetch failed"): + client.verify() + + @patch('tinfoil.client.fetch_attestation') + def test_attestation_verification_failure_blocks_verify(self, mock_fetch): + """If attestation verification fails, verify() must raise.""" + mock_doc = MagicMock() + mock_doc.verify.side_effect = ValueError("TDX attestation verification failed") + mock_fetch.return_value = mock_doc + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(ValueError, match="TDX attestation verification failed"): + client.verify() + + @patch('tinfoil.client.fetch_attestation') + @patch('tinfoil.client.fetch_latest_digest') + @patch('tinfoil.client.fetch_attestation_bundle') + @patch('tinfoil.client.verify_attestation') + def test_measurement_mismatch_blocks_verify( + self, mock_verify_att, mock_fetch_bundle, mock_fetch_digest, mock_fetch_attestation + ): + """If code measurements don't match runtime, verify() must raise.""" + # Setup mocks + mock_fetch_digest.return_value = "test_digest" + mock_fetch_bundle.return_value = {} + + # Runtime measurement from enclave + runtime_measurement = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["runtime_measurement"] + ) + mock_verification = MagicMock() + mock_verification.measurement = runtime_measurement + mock_doc = MagicMock() + mock_doc.verify.return_value = mock_verification + mock_fetch_attestation.return_value = mock_doc + + # Code measurement from sigstore (different!) + code_measurement = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["different_code_measurement"] + ) + mock_verify_att.return_value = code_measurement + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(MeasurementMismatchError): + client.verify() + + @patch('tinfoil.client.fetch_attestation') + @patch('tinfoil.client.fetch_latest_digest') + @patch('tinfoil.client.fetch_attestation_bundle') + @patch('tinfoil.client.verify_attestation') + @patch('tinfoil.client.fetch_latest_hardware_measurements') + @patch('tinfoil.client.verify_tdx_hardware') + def test_hardware_mismatch_blocks_verify( + self, mock_verify_hw, mock_fetch_hw, mock_verify_att, + mock_fetch_bundle, mock_fetch_digest, mock_fetch_attestation + ): + """If TDX hardware measurements don't match, verify() must raise.""" + # Setup mocks + mock_fetch_digest.return_value = "test_digest" + mock_fetch_bundle.return_value = {} + + # TDX runtime measurement from enclave + runtime_measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["mrtd", "rtmr0", "rtmr1", "rtmr2", "rtmr3"] + ) + mock_verification = MagicMock() + mock_verification.measurement = runtime_measurement + mock_doc = MagicMock() + mock_doc.verify.return_value = mock_verification + mock_fetch_attestation.return_value = mock_doc + + # Code measurement from sigstore + code_measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["mrtd", "rtmr0", "rtmr1", "rtmr2", "rtmr3"] + ) + mock_verify_att.return_value = code_measurement + + # Hardware verification fails + mock_fetch_hw.return_value = [] + mock_verify_hw.side_effect = HardwareMeasurementError("no matching hardware platform") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(HardwareMeasurementError, match="no matching hardware platform"): + client.verify() + + @patch('tinfoil.client.fetch_attestation') + def test_http_client_not_created_on_verification_failure(self, mock_fetch): + """make_secure_http_client() must not return a client if verify() 
fails.""" + mock_fetch.side_effect = Exception("Attestation failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(Exception, match="Attestation failed"): + client.make_secure_http_client() + + # Verify ground_truth was never set + assert client.ground_truth is None + + @patch('tinfoil.client.fetch_attestation') + def test_async_http_client_not_created_on_verification_failure(self, mock_fetch): + """make_secure_async_http_client() must not return a client if verify() fails.""" + mock_fetch.side_effect = Exception("Attestation failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(Exception, match="Attestation failed"): + client.make_secure_async_http_client() + + # Verify ground_truth was never set + assert client.ground_truth is None + + @patch('tinfoil.client.fetch_attestation') + def test_get_http_client_calls_verify_first(self, mock_fetch): + """get_http_client() must call verify() before returning client.""" + mock_fetch.side_effect = Exception("Verification failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(Exception, match="Verification failed"): + client.get_http_client() + + @patch('tinfoil.client.fetch_attestation') + def test_make_request_calls_verify_first(self, mock_fetch): + """make_request() must call verify() before making request.""" + import urllib.request + + mock_fetch.side_effect = Exception("Verification failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + req = urllib.request.Request("https://test.enclave.sh/api") + + with pytest.raises(Exception, match="Verification failed"): + client.make_request(req) + + +class TestDirectMeasurementVerification: + """Tests for direct measurement verification (no repo).""" + + @patch('tinfoil.client.fetch_attestation') + def test_snp_measurement_mismatch_raises(self, mock_fetch): + """If SNP measurement doesn't match provided measurement, must raise.""" + # Runtime measurement from enclave + runtime_measurement = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["actual_snp_measurement"] + ) + mock_verification = MagicMock() + mock_verification.measurement = runtime_measurement + mock_doc = MagicMock() + mock_doc.verify.return_value = mock_verification + mock_fetch.return_value = mock_doc + + # Expect different measurement + client = SecureClient( + enclave="test.enclave.sh", + measurement={"snp_measurement": "expected_snp_measurement"} + ) + + with pytest.raises(ValueError, match="SNP measurement mismatch"): + client.verify() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_verification_failures_integration.py b/tests/test_verification_failures_integration.py new file mode 100644 index 0000000..5b99003 --- /dev/null +++ b/tests/test_verification_failures_integration.py @@ -0,0 +1,143 @@ +""" +Integration tests for verification failure handling. + +These tests verify that REAL verification failures are properly caught +at the integration level - not with mocks, but with actual enclaves +and mismatched repos. + +This guards against bugs like the Go verifier issue where hardware +measurement mismatches would silently continue instead of failing. 
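+
+These tests need network access and a live router enclave; the module-scoped
+router_enclave fixture skips them if no router address can be fetched.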
+""" + +import pytest + +from tinfoil.client import SecureClient, get_router_address +from tinfoil.attestation import MeasurementMismatchError + +pytestmark = pytest.mark.integration + +# Router runs confidential-model-router, use gpt-oss as wrong repo +CORRECT_REPO = "tinfoilsh/confidential-model-router" +WRONG_REPO = "tinfoilsh/confidential-gpt-oss-120b-free" + + +@pytest.fixture(scope="module") +def router_enclave(): + """Fetch a router enclave address, skip all tests if unavailable.""" + try: + return get_router_address() + except Exception as e: + pytest.skip(f"Could not fetch router address: {e}") + + +class TestMeasurementMismatchIntegration: + """ + Tests that measurement mismatches between enclave and repo + are properly caught at the integration level. + """ + + def test_wrong_repo_fails_verification(self, router_enclave): + """ + Verifying an enclave against the WRONG repo must fail. + + This test: + 1. Gets a real router enclave + 2. Tries to verify it against gpt-oss-120b-free repo (WRONG) + 3. Expects MeasurementMismatchError + + This catches bugs where measurement comparison is skipped. + """ + print(f"\nTesting: {router_enclave}") + print(f"Against WRONG repo: {WRONG_REPO}") + print("Expected: MeasurementMismatchError") + + client = SecureClient(enclave=router_enclave, repo=WRONG_REPO) + + with pytest.raises(MeasurementMismatchError): + client.verify() + + print("✓ Correctly rejected mismatched measurements") + + def test_wrong_repo_blocks_http_client(self, router_enclave): + """ + make_secure_http_client() must fail if measurements don't match. + + This catches bugs where HTTP client is created despite failed verification. + """ + client = SecureClient(enclave=router_enclave, repo=WRONG_REPO) + + with pytest.raises(MeasurementMismatchError): + client.make_secure_http_client() + + # Ground truth should NOT be set if verification failed + assert client.ground_truth is None, \ + "ground_truth was set despite verification failure!" + + print("✓ HTTP client correctly blocked on mismatch") + + def test_correct_repo_passes_verification(self, router_enclave): + """ + Sanity check: correct repo should pass verification. + """ + client = SecureClient(enclave=router_enclave, repo=CORRECT_REPO) + ground_truth = client.verify() + + assert ground_truth is not None + assert ground_truth.public_key is not None + assert ground_truth.measurement is not None + + print(f"✓ Correct repo verified successfully") + print(f" Measurement: {ground_truth.measurement.fingerprint()[:32]}...") + + +class TestDirectMeasurementIntegration: + """ + Tests for the direct measurement verification path. + """ + + def test_wrong_pinned_measurement_fails(self, router_enclave): + """ + If user provides a specific measurement that doesn't match + the enclave, verification must fail. + """ + # This is clearly a fake measurement + fake_measurement = { + "snp_measurement": "0000000000000000000000000000000000000000000000000000000000000000" + } + + client = SecureClient(enclave=router_enclave, measurement=fake_measurement) + + with pytest.raises(ValueError, match="measurement mismatch"): + client.verify() + + print("✓ Correctly rejected fake pinned measurement") + + +class TestNoSilentFailures: + """ + Tests that verification failures are NEVER silently ignored. + """ + + def test_verification_required_before_request(self, router_enclave): + """ + Any attempt to use the client must trigger verification. + If verification would fail, the request must also fail. 
+ """ + client = SecureClient(enclave=router_enclave, repo=WRONG_REPO) + + # Try to get HTTP client - should fail + with pytest.raises(MeasurementMismatchError): + client.get_http_client() + + # Try to make request - should also fail + import urllib.request + req = urllib.request.Request(f"https://{router_enclave}/health") + + with pytest.raises(MeasurementMismatchError): + client.make_request(req) + + print("✓ All client methods properly block on verification failure") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From 441dd2362593b4945d80528c97c4578a3e15005b Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:15:05 -0500 Subject: [PATCH 08/19] refactor: isolate SEV attestation into dedicated modules Move SEV-specific logic into attestation_sev.py, rename abi_sevsnp.py to abi_sev.py and verify.py to verify_sev.py. Add SevAttestationError, report_data bounds checking, frozen ValidationOptions, and use pyasn1 for DER decoding in VCEK URL construction. --- src/tinfoil/attestation/abi_sev.py | 449 ++++++++++++++++++ src/tinfoil/attestation/attestation_sev.py | 159 +++++++ src/tinfoil/attestation/validate_sev.py | 288 ++++++++++++ src/tinfoil/attestation/verify_sev.py | 515 +++++++++++++++++++++ tests/test_sev_validation.py | 6 +- 5 files changed, 1414 insertions(+), 3 deletions(-) create mode 100644 src/tinfoil/attestation/abi_sev.py create mode 100644 src/tinfoil/attestation/attestation_sev.py create mode 100644 src/tinfoil/attestation/validate_sev.py create mode 100644 src/tinfoil/attestation/verify_sev.py diff --git a/src/tinfoil/attestation/abi_sev.py b/src/tinfoil/attestation/abi_sev.py new file mode 100644 index 0000000..cb9e079 --- /dev/null +++ b/src/tinfoil/attestation/abi_sev.py @@ -0,0 +1,449 @@ +from dataclasses import dataclass +from enum import IntEnum + +POLICY_RESERVED_1_BIT = 17 +REPORT_SIZE = 0x4A0 # 1184 bytes +SIGNATURE_OFFSET = 0x2A0 +ECDSA_RS_SIZE = 72 +ECDSA_P384_SHA384_SIGNATURE_SIZE = ECDSA_RS_SIZE + ECDSA_RS_SIZE + +ZEN3ZEN4_FAMILY = 0x19 +ZEN5_FAMILY = 0x1A +MILAN_MODEL = 0 | 1 +GENOA_MODEL = (1 << 4) | 1 +TURIN_MODEL = 2 + +class ReportSigner(IntEnum): + VcekReportSigner = 0 + # VlekReportSigner is the SIGNING_KEY value for if the VLEK signed the attestation report. + VlekReportSigner = 1 + endorseReserved2 = 2 + endorseReserved3 = 3 + endorseReserved4 = 4 + endorseReserved5 = 5 + endorseReserved6 = 6 + # NoneReportSigner is the SIGNING_KEY value for if the attestation report is not signed. + NoneReportSigner = 7 + +# SignerInfo represents information about the signing circumstances for the attestation report. +class SignerInfo: + # SigningKey represents kind of key by which a report was signed. + signingKey: ReportSigner + # MaskChipKey is true if the host chose to enable CHIP_ID masking, to cause the report's CHIP_ID + # to be all zeros. + maskChipKey: bool + # AuthorKeyEn is true if the VM is launched with an IDBLOCK that includes an author key. 
+    authorKeyEn: bool
+
+@dataclass
+class TCBParts:
+    """Represents the decomposed parts of a TCB version"""
+    ucode_spl: int
+    snp_spl: int
+    tee_spl: int
+    bl_spl: int
+
+    def __str__(self) -> str:
+        """Return a human-friendly string with all component SPL values."""
+        # Print fields in order starting with the least-significant component (bl_spl)
+        return (
+            "TCBParts("
+            f"bl_spl=0x{self.bl_spl:02x}, "
+            f"tee_spl=0x{self.tee_spl:02x}, "
+            f"snp_spl=0x{self.snp_spl:02x}, "
+            f"ucode_spl=0x{self.ucode_spl:02x})"
+        )
+
+    @classmethod
+    def from_int(cls, tcb: int) -> "TCBParts":
+        """Build a TCBParts instance from a 64-bit packed TCB value."""
+        return cls(
+            ucode_spl=((tcb >> 56) & 0xff),
+            snp_spl=((tcb >> 48) & 0xff),
+            tee_spl=((tcb >> 8) & 0xff),
+            bl_spl=((tcb >> 0) & 0xff),
+        )
+
+    def meets_minimum(self, minimum: "TCBParts") -> bool:
+        """Check if this TCB meets minimum requirements (component-wise)."""
+        return (
+            self.bl_spl >= minimum.bl_spl and
+            self.tee_spl >= minimum.tee_spl and
+            self.snp_spl >= minimum.snp_spl and
+            self.ucode_spl >= minimum.ucode_spl
+        )
+
+@dataclass
+class SnpPlatformInfo:
+    """Decoded view of the 64-bit PLATFORM_INFO field."""
+
+    smt_enabled: bool
+    tsme_enabled: bool
+    ecc_enabled: bool
+    rapl_disabled: bool
+    ciphertext_hiding_dram_enabled: bool
+    alias_check_complete: bool
+    tio_enabled: bool
+
+    @classmethod
+    def from_int(cls, value: int) -> "SnpPlatformInfo":
+        return cls(
+            smt_enabled=bool(value & (1 << 0)),
+            tsme_enabled=bool(value & (1 << 1)),
+            ecc_enabled=bool(value & (1 << 2)),
+            rapl_disabled=bool(value & (1 << 3)),
+            ciphertext_hiding_dram_enabled=bool(value & (1 << 4)),
+            alias_check_complete=bool(value & (1 << 5)),
+            tio_enabled=bool(value & (1 << 7)),
+        )
+
+    def __str__(self) -> str:  # pragma: no cover - formatting helper
+        return (
+            "SnpPlatformInfo("
+            f"SMTEnabled={self.smt_enabled}, "
+            f"TSMEEnabled={self.tsme_enabled}, "
+            f"ECCEnabled={self.ecc_enabled}, "
+            f"RAPLDisabled={self.rapl_disabled}, "
+            f"CiphertextHidingDRAMEnabled={self.ciphertext_hiding_dram_enabled}, "
+            f"AliasCheckComplete={self.alias_check_complete}, "
+            f"TIOEnabled={self.tio_enabled})"
+        )
+
+
+@dataclass
+class SnpPolicy:
+    """Decoded view of the 64-bit POLICY field (bits 0-25)."""
+
+    abi_minor: int
+    abi_major: int
+    smt: bool
+    migrate_ma: bool
+    debug: bool
+    single_socket: bool
+    cxl_allowed: bool
+    mem_aes256_xts: bool
+    rapl_dis: bool
+    ciphertext_hiding_dram: bool
+    page_swap_disabled: bool
+
+    def __str__(self) -> str:  # pragma: no cover - formatting helper
+        return (
+            "SnpPolicy("
+            f"ABIMajor={self.abi_major}, ABIMinor={self.abi_minor}, "
+            f"SMT={self.smt}, MigrateMA={self.migrate_ma}, Debug={self.debug}, "
+            f"SingleSocket={self.single_socket}, CXLAllowed={self.cxl_allowed}, "
+            f"MemAES256XTS={self.mem_aes256_xts}, RAPLDis={self.rapl_dis}, "
+            f"CipherTextHidingDRAM={self.ciphertext_hiding_dram}, PageSwapDisabled={self.page_swap_disabled})"
+        )
+
+    @classmethod
+    def from_int(cls, value: int) -> "SnpPolicy":
+        """Parse the guest policy bit-field following the AMD SEV-SNP spec."""
+        return cls(
+            abi_minor=value & 0xFF,
+            abi_major=(value >> 8) & 0xFF,
+            smt=bool(value & (1 << 16)),
+            migrate_ma=bool(value & (1 << 18)),
+            debug=bool(value & (1 << 19)),
+            single_socket=bool(value & (1 << 20)),
+            cxl_allowed=bool(value & (1 << 21)),
+            mem_aes256_xts=bool(value & (1 << 22)),
+            rapl_dis=bool(value & (1 << 23)),
+            ciphertext_hiding_dram=bool(value & (1 << 24)),
+            page_swap_disabled=bool(value & (1 << 25)),
+        )
+
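+# Illustrative example (comments only, not executed): decoding a packed POLICY
+# value with the bit layout parsed by SnpPolicy.from_int above. 0x3011F is a
+# made-up sample (reserved bit 17 set, SMT allowed, ABI version 1.31), not a
+# value taken from any real report:
+#
+#     policy = SnpPolicy.from_int(0x3011F)
+#     assert policy.abi_major == 1 and policy.abi_minor == 0x1F
+#     assert policy.smt and not policy.debug
+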
+@dataclass
+class Report:
+    """SEV-SNP attestation report"""
+    version: int  # Should be 2 for revision 1.55, 3 for revision 1.56, 5 for revision 1.58
+    guest_svn: int
+    policy: int
+    policy_parsed: SnpPolicy
+    family_id: bytes  # Should be 16 bytes long
+    image_id: bytes  # Should be 16 bytes long
+    vmpl: int
+    signature_algo: int
+    current_tcb: int
+    platform_info: int
+    platform_info_parsed: SnpPlatformInfo
+    signer_info: int  # AuthorKeyEn, MaskChipKey, SigningKey
+    signer_info_parsed: SignerInfo
+    report_data: bytes  # Should be 64 bytes long
+    measurement: bytes  # Should be 48 bytes long
+    host_data: bytes  # Should be 32 bytes long
+    id_key_digest: bytes  # Should be 48 bytes long
+    author_key_digest: bytes  # Should be 48 bytes long
+    report_id: bytes  # Should be 32 bytes long
+    report_id_ma: bytes  # Should be 32 bytes long
+    reported_tcb: int
+    chip_id: bytes  # Should be 64 bytes long
+    committed_tcb: int
+    current_build: int
+    current_minor: int
+    current_major: int
+    committed_build: int
+    committed_minor: int
+    committed_major: int
+    launch_tcb: int
+    signed_data: bytes
+    signature: bytes  # Should be 512 bytes long
+    family: int    # CPU family byte (parsed for report version >= 3, synthesized for version 2)
+    model: int     # CPU model byte
+    stepping: int  # CPU stepping byte
+    productName: str
+
+    def __init__(self, data: bytes):
+        """
+        Parse an attestation report from raw bytes in SEV-SNP ABI format.
+
+        Args:
+            data: Raw bytes of the attestation report
+        Raises:
+            ValueError: If the data is too short or any field is malformed
+        """
+
+        if len(data) < REPORT_SIZE:
+            raise ValueError(f"Input is 0x{len(data):x} bytes; an SEV-SNP attestation report is 0x{REPORT_SIZE:x} bytes")
+
+        # Parse all fields using little-endian byte order
+        self.version = int.from_bytes(data[0x00:0x04], byteorder='little')
+        self.guest_svn = int.from_bytes(data[0x04:0x08], byteorder='little')
+        self.policy = int.from_bytes(data[0x08:0x10], byteorder='little')
+
+        # Check reserved bit must be 1
+        if not (self.policy & (1 << POLICY_RESERVED_1_BIT)):
+            raise ValueError(f"policy[{POLICY_RESERVED_1_BIT}] is reserved, must be 1, got 0")
+
+        # Check bits 63-26 must be zero
+        if self.policy >> 26:
+            raise ValueError("policy bits 63-26 must be zero")
+
+        self.family_id = data[0x10:0x20]  # 16 bytes
+        self.image_id = data[0x20:0x30]  # 16 bytes
+        self.vmpl = int.from_bytes(data[0x30:0x34], byteorder='little')
+        self.signature_algo = int.from_bytes(data[0x34:0x38], byteorder='little')
+        self.current_tcb = int.from_bytes(data[0x38:0x40], byteorder='little')
+
+        try:
+            mbz64(int(self.current_tcb), "current_tcb", 47, 16)
+        except ValueError as e:
+            raise ValueError(f"current_tcb not correctly formed: {e}")
+
+        self.platform_info = int.from_bytes(data[0x40:0x48], byteorder='little')
+        # Decode additional helper structures for easier consumption later.
+        self.policy_parsed = SnpPolicy.from_int(self.policy)
+        self.platform_info_parsed = SnpPlatformInfo.from_int(self.platform_info)
+        self.signer_info = int.from_bytes(data[0x48:0x4C], byteorder='little')
+
+        self.signer_info_parsed = SignerInfo()
+        try:
+            mbz64(int(self.signer_info), "signer_info", 31, 5)
+        except ValueError as e:
+            raise ValueError(f"signer_info not correctly formed: {e}")
+
+        self.signer_info_parsed.signingKey = ReportSigner((self.signer_info >> 2) & 7)
+        if self.signer_info_parsed.signingKey != ReportSigner.VcekReportSigner:
+            raise ValueError(f"This implementation only supports VCEK signed reports. Got {self.signer_info_parsed.signingKey}")
+        self.signer_info_parsed.maskChipKey = (self.signer_info & 2) != 0
+        self.signer_info_parsed.authorKeyEn = (self.signer_info & 1) != 0
+
+        # 0x4C-0x50 is MBZ (Must Be Zero)
+        try:
+            mbz(data, 0x4C, 0x50)
+        except ValueError as e:
+            raise ValueError(f"reserved range 0x4C:0x50 not correctly formed: {e}")
+
+        self.report_data = data[0x50:0x90]  # 64 bytes
+        self.measurement = data[0x90:0xC0]  # 48 bytes
+        self.host_data = data[0xC0:0xE0]  # 32 bytes
+        self.id_key_digest = data[0xE0:0x110]  # 48 bytes
+        self.author_key_digest = data[0x110:0x140]  # 48 bytes
+        self.report_id = data[0x140:0x160]  # 32 bytes
+        self.report_id_ma = data[0x160:0x180]  # 32 bytes
+        self.reported_tcb = int.from_bytes(data[0x180:0x188], byteorder='little')
+
+        try:
+            mbz64(int(self.reported_tcb), "reported_tcb", 47, 16)
+        except ValueError as e:
+            raise ValueError(f"reported_tcb not correctly formed: {e}")
+
+        mbzLo = 0x188
+        # Version specific parsing
+        if self.version >= 3:  # Report Version 3
+            self.family = data[0x188]
+            self.model = data[0x189]
+            self.stepping = data[0x18A]
+            self._init_product_name()
+            mbzLo = 0x18B
+        elif self.version == 2:  # Report Version 2
+            self.family = ZEN3ZEN4_FAMILY
+            self.model = GENOA_MODEL
+            self.stepping = 0x01
+            self.productName = "Genoa"
+        else:
+            raise ValueError("Unknown report version")
+
+        try:
+            mbz(data, mbzLo, 0x1A0)
+        except ValueError as e:
+            raise ValueError(f"reserved range 0x{mbzLo:x}:0x1A0 not correctly formed: {e}")
+
+        self.chip_id = data[0x1A0:0x1E0]  # 64 bytes
+        self.committed_tcb = int.from_bytes(data[0x1E0:0x1E8], byteorder='little')
+
+        try:
+            mbz64(int(self.committed_tcb), "committed_tcb", 47, 16)
+        except ValueError as e:
+            raise ValueError(f"committed_tcb not correctly formed: {e}")
+
+        # Version fields
+        self.current_build = data[0x1E8]
+        self.current_minor = data[0x1E9]
+        self.current_major = data[0x1EA]
+
+        try:
+            mbz(data, 0x1EB, 0x1EC)
+        except ValueError as e:
+            raise ValueError(f"reserved byte 0x1EB not correctly formed: {e}")
+
+        self.committed_build = data[0x1EC]
+        self.committed_minor = data[0x1ED]
+        self.committed_major = data[0x1EE]
+
+        try:
+            mbz(data, 0x1EF, 0x1F0)
+        except ValueError as e:
+            raise ValueError(f"reserved byte 0x1EF not correctly formed: {e}")
+
+        self.launch_tcb = int.from_bytes(data[0x1F0:0x1F8], byteorder='little')
+
+        try:
+            mbz64(int(self.launch_tcb), "launch_tcb", 47, 16)
+        except ValueError as e:
+            raise ValueError(f"launch_tcb not correctly formed: {e}")
+
+        try:
+            mbz(data, 0x1F8, SIGNATURE_OFFSET)
+        except ValueError as e:
+            raise ValueError(f"reserved range 0x1F8:0x{SIGNATURE_OFFSET:x} not correctly formed: {e}")
+
+        if self.signature_algo == 1:  # ECDSA P-384 SHA-384
+            try:
+                mbz(data, SIGNATURE_OFFSET + ECDSA_P384_SHA384_SIGNATURE_SIZE, REPORT_SIZE)
+            except ValueError as e:
+                raise ValueError(f"signature padding not correctly formed: {e}")
+
+        self.signed_data = data[0:SIGNATURE_OFFSET]
+        self.signature = data[SIGNATURE_OFFSET:REPORT_SIZE]
+
+    def get_fms(self):
+        return self.family, self.model, self.stepping
+
+    def _init_product_name(self):
+        # Combined extended values
+        self.productName = "Unknown"
+        if self.family == ZEN3ZEN4_FAMILY:
+            if self.model == MILAN_MODEL:
+                self.productName = "Milan"
+            elif self.model == GENOA_MODEL:
+                self.productName = "Genoa"
+        elif self.family == ZEN5_FAMILY:
+            if self.model == TURIN_MODEL:
+                self.productName = "Turin"
+
+    def print_report(self):
+        """Print all relevant fields of the SEV-SNP attestation report in a human-readable format."""
+        print("=== SEV-SNP Attestation Report ===")
+        print(f"Version: {self.version}")
+        print(f"Guest 
SVN: {self.guest_svn}") + print(f"Policy: 0x{self.policy:x}") + print(f" -> {self.policy_parsed}") + print(f"Family ID: {self.family_id.hex()}") + print(f"Image ID: {self.image_id.hex()}") + print(f"VMPL: {self.vmpl}") + print(f"Signature Algorithm: {self.signature_algo}") + print(f"Current TCB: 0x{self.current_tcb:x}") + print(f" -> {TCBParts.from_int(self.current_tcb)}") + print(f"Platform Info: 0x{self.platform_info:x}") + print(f" -> {self.platform_info_parsed}") + print(f"Signer Info: 0x{self.signer_info:x}") + print(f" - Signing Key: {self.signer_info_parsed.signingKey}") + print(f" - Mask Chip Key: {self.signer_info_parsed.maskChipKey}") + print(f" - Author Key Enabled: {self.signer_info_parsed.authorKeyEn}") + print(f"Report Data: {self.report_data.hex()}") + print(f"Measurement: {self.measurement.hex()}") + print(f"Host Data: {self.host_data.hex()}") + print(f"ID Key Digest: {self.id_key_digest.hex()}") + print(f"Author Key Digest: {self.author_key_digest.hex()}") + print(f"Report ID: {self.report_id.hex()}") + print(f"Report ID MA: {self.report_id_ma.hex()}") + print(f"Reported TCB: 0x{self.reported_tcb:x}") + print(f" -> {TCBParts.from_int(self.reported_tcb)}") + print(f"Chip ID: {self.chip_id.hex()}") + print(f"Committed TCB: 0x{self.committed_tcb:x}") + print(f" -> {TCBParts.from_int(self.committed_tcb)}") + print(f"Current Version: {self.current_major}.{self.current_minor}.{self.current_build}") + print(f"Committed Version: {self.committed_major}.{self.committed_minor}.{self.committed_build}") + print(f"Launch TCB: 0x{self.launch_tcb:x}") + print(f" -> {TCBParts.from_int(self.launch_tcb)}") + print(f"Product Name: {self.productName}") + if hasattr(self, 'family') and hasattr(self, 'model') and hasattr(self, 'stepping'): + print(f"CPU: Family=0x{self.family:02x}, Model=0x{self.model:02x}, Stepping=0x{self.stepping:02x}") + print(f"Signature Length: {len(self.signature)} bytes") + print("=" * 40) + +## HELPER FUNCTIONS + +def find_non_zero(data: bytes, lo: int, hi: int) -> int: + """ + Returns the first index which is not zero, otherwise returns hi. + + Args: + data: Bytes object to search through + lo: Starting index (inclusive) + hi: Ending index (exclusive) + Returns: + Index of first non-zero byte, or hi if all bytes are zero + """ + for i in range(lo, hi): + if data[i] != 0: + return i + return hi + +def mbz(data: bytes, lo: int, hi: int) -> None: + """ + Checks if a range of bytes is all zeros. + + Args: + data: Bytes object to check + lo: Starting index (inclusive) + hi: Ending index (exclusive) + Raises: + ValueError: If any byte in the range is non-zero + """ + first_non_zero = find_non_zero(data, lo, hi) + if first_non_zero != hi: + # Convert the slice to hex string for error message + hex_str = data[lo:hi].hex() + raise ValueError(f"mbz range [0x{lo:x}:0x{hi:x}] not all zero: {hex_str}") + +def mbz64(data: int, base: str, hi: int, lo: int) -> None: + """ + Checks if a range of bits in an integer is all zeros. 
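+
+    For example, mbz64(tcb, "current_tcb", 47, 16) checks that bits 16-47 of a
+    packed TCB value are zero (the reserved gap between the TEE SPL in bits
+    8-15 and the SNP SPL in bits 48-55).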
+ + Args: + data: Integer to check + base: String identifier for error message + hi: Highest bit position (inclusive) + lo: Lowest bit position (inclusive) + Raises: + ValueError: If any bit in the range is non-zero + """ + # Create mask for the bit range + mask = (1 << (hi - lo + 1)) - 1 + # Extract and check the bits + bits = (data >> lo) & mask + if bits != 0: + raise ValueError(f"mbz range {base}[0x{lo:x}:0x{hi:x}] not all zero: {hex(data)}") + diff --git a/src/tinfoil/attestation/attestation_sev.py b/src/tinfoil/attestation/attestation_sev.py new file mode 100644 index 0000000..932b0ed --- /dev/null +++ b/src/tinfoil/attestation/attestation_sev.py @@ -0,0 +1,159 @@ +""" +AMD SEV-SNP Attestation Orchestration Module. + +This module provides the high-level entry point for AMD SEV-SNP attestation verification. + +""" + +import base64 +from typing import Optional + +from .abi_sev import TCBParts, SnpPolicy, SnpPlatformInfo, Report +from .types import Measurement, Verification, PredicateType, TLS_KEY_FP_SIZE, HPKE_KEY_SIZE, safe_gzip_decompress +from .verify_sev import verify_attestation, CertificateChain +from .validate_sev import validate_report, ValidationOptions + + +class SevAttestationError(Exception): + """Raised when SEV-SNP attestation verification fails.""" + pass + + +# ============================================================================= +# Orchestration Constants +# ============================================================================= + +# Minimum TCB requirements for AMD SEV-SNP +min_tcb = TCBParts( + bl_spl=0x7, + tee_spl=0, + snp_spl=0xe, + ucode_spl=0x48, +) + +# Default validation options for AMD SEV-SNP attestation +default_validation_options = ValidationOptions( + guest_policy=SnpPolicy( + abi_minor=0, + abi_major=0, + smt=True, + migrate_ma=False, + debug=False, + single_socket=False, + cxl_allowed=False, + mem_aes256_xts=False, + rapl_dis=False, + ciphertext_hiding_dram=False, + page_swap_disabled=False, + ), + minimum_guest_svn=0, + minimum_build=21, + minimum_version=(1 << 8) | 55, # 1.55 + minimum_tcb=min_tcb, + minimum_launch_tcb=min_tcb, + permit_provisional_firmware=False, + platform_info=SnpPlatformInfo( + smt_enabled=True, + tsme_enabled=False, + ecc_enabled=False, + rapl_disabled=False, + ciphertext_hiding_dram_enabled=False, + alias_check_complete=False, + tio_enabled=False, + ), + require_author_key=False, + require_id_block=False, +) + + +# ============================================================================= +# Main Entry Points +# ============================================================================= + +def verify_sev_attestation_v2(attestation_doc: str) -> Verification: + """Verify SEV attestation document (v2 format) and return verification result. 
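+
+    In the v2 format, report_data carries the TLS key fingerprint (32 bytes)
+    followed by the HPKE public key (32 bytes); both are returned hex-encoded.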
+ + Raises: + ValueError: If verification fails + """ + try: + report = verify_sev_report(attestation_doc) + except SevAttestationError as e: + raise ValueError(f"SEV attestation verification failed: {e}") from e + + # Create measurement object + measurement = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=[ + report.measurement.hex() + ] + ) + + keys = report.report_data + required_len = TLS_KEY_FP_SIZE + HPKE_KEY_SIZE + if len(keys) < required_len: + raise ValueError( + f"report_data too short: {len(keys)} bytes, need at least {required_len}" + ) + tls_key_fp = keys[0:TLS_KEY_FP_SIZE] + hpke_public_key = keys[TLS_KEY_FP_SIZE:TLS_KEY_FP_SIZE + HPKE_KEY_SIZE] + + return Verification( + measurement=measurement, + public_key_fp=tls_key_fp.hex(), + hpke_public_key=hpke_public_key.hex() + ) + + +def verify_sev_report( + attestation_doc: str, + is_compressed: bool = True, + validation_options: Optional[ValidationOptions] = None, +) -> Report: + """Verify SEV attestation document and return the parsed report. + + Args: + attestation_doc: Base64-encoded attestation document + is_compressed: Whether the document is gzip-compressed (default True) + validation_options: Custom validation options; uses module defaults if None + """ + options = validation_options if validation_options is not None else default_validation_options + + try: + att_doc_bytes = base64.b64decode(attestation_doc) + except Exception as e: + raise SevAttestationError(f"Failed to decode base64: {e}") from e + + if is_compressed: + try: + att_doc_bytes = safe_gzip_decompress(att_doc_bytes) + except SevAttestationError: + raise + except Exception as e: + raise SevAttestationError(f"Failed to decompress attestation document: {e}") from e + + try: + report = Report(att_doc_bytes) + except Exception as e: + raise SevAttestationError(f"Failed to parse report: {e}") from e + + try: + chain = CertificateChain.from_report(report) + except Exception as e: + raise SevAttestationError(f"Failed to build certificate chain: {e}") from e + + try: + res = verify_attestation(chain, report) + except Exception as e: + raise SevAttestationError(f"Failed to verify attestation: {e}") from e + + if not res: + raise SevAttestationError("Attestation verification failed!") + + try: + validate_report(report, chain, options) + except Exception as e: + raise SevAttestationError(f"Failed to validate report: {e}") from e + + return report + diff --git a/src/tinfoil/attestation/validate_sev.py b/src/tinfoil/attestation/validate_sev.py new file mode 100644 index 0000000..6e911cd --- /dev/null +++ b/src/tinfoil/attestation/validate_sev.py @@ -0,0 +1,288 @@ +from dataclasses import dataclass, field +from typing import Optional, List, Dict + +from .abi_sev import ( + Report, + TCBParts, + SnpPolicy, + SnpPlatformInfo, + ReportSigner, +) +from .verify_sev import CertificateChain + +@dataclass(frozen=True) +class ValidationOptions: + """ + Verification options for an SEV-SNP attestation report. + Any attribute left as ``None`` / empty will not be checked + by the validation routine. 
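+
+    Example (illustrative values only)::
+
+        opts = ValidationOptions(minimum_guest_svn=1, vmpl=0)
+        validate_report(report, chain, opts)  # raises ValueError on any failed check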
+ """ + # Policy / version constraints + guest_policy: Optional[SnpPolicy] = None + minimum_guest_svn: Optional[int] = None + minimum_build: Optional[int] = None # Firmware build (uint8) + minimum_version: Optional[int] = None # Firmware API version (uint16) + + # TCB requirements + minimum_tcb: Optional[TCBParts] = None + minimum_launch_tcb: Optional[TCBParts] = None + permit_provisional_firmware: bool = False + + # Field equality checks (length is not enforced here; caller must ensure correctness) + report_data: Optional[bytes] = None # 64 bytes + host_data: Optional[bytes] = None # 32 bytes + image_id: Optional[bytes] = None # 16 bytes + family_id: Optional[bytes] = None # 16 bytes + report_id: Optional[bytes] = None # 32 bytes + report_id_ma: Optional[bytes] = None # 32 bytes + measurement: Optional[bytes] = None # 48 bytes + chip_id: Optional[bytes] = None # 64 bytes + + # Misc + platform_info: Optional[SnpPlatformInfo] = None + vmpl: Optional[int] = None # Expected VMPL (0-3) + + # TODO: ID-block / author key requirements + require_author_key: bool = False + require_id_block: bool = False + # trusted_author_keys: List[x509.Certificate] = field(default_factory=list) + # trusted_author_key_hashes: List[bytes] = field(default_factory=list) + # trusted_id_keys: List[x509.Certificate] = field(default_factory=list) + # trusted_id_key_hashes: List[bytes] = field(default_factory=list) + + # TODO:Extended certificate-table options + # cert_table_options: Dict[str, CertEntryOption] = field(default_factory=dict) + + +def validate_report(report: Report, chain: CertificateChain, options: ValidationOptions): + """ + Validate the supplied SEV-SNP attestation report according to *options*. + Raises ValueError if validation fails. + """ + + # Policy constraints + if options.guest_policy is not None: + _validate_policy(report.policy_parsed, options.guest_policy) + + if options.minimum_guest_svn is not None: + if report.guest_svn < options.minimum_guest_svn: + raise ValueError(f"Guest SVN {report.guest_svn} is less than minimum required {options.minimum_guest_svn}") + + if options.minimum_build is not None: + if report.current_build < options.minimum_build: + raise ValueError(f"Current SNP firmware build number {report.current_build} is less than minimum required {options.minimum_build}") + if report.committed_build < options.minimum_build: + raise ValueError(f"Committed SNP firmware build number {report.committed_build} is less than minimum required {options.minimum_build}") + + if options.minimum_version is not None: + # Combine major/minor into single version number for comparison + current_version = (report.current_major << 8) | report.current_minor + committed_version = (report.committed_major << 8) | report.committed_minor + if current_version < options.minimum_version: + raise ValueError(f"Current SNP firmwareversion {report.current_major}.{report.current_minor} is less than minimum required {options.minimum_version >> 8}.{options.minimum_version & 0xff}") + if committed_version < options.minimum_version: + raise ValueError(f"Committed SNP firmware version {report.committed_major}.{report.committed_minor} is less than minimum required {options.minimum_version >> 8}.{options.minimum_version & 0xff}") + + # TCB requirements + if options.minimum_tcb is not None: + current_tcb_parts = TCBParts.from_int(report.current_tcb) + committed_tcb_parts = TCBParts.from_int(report.committed_tcb) + reported_tcb_parts = TCBParts.from_int(report.reported_tcb) + if not 
current_tcb_parts.meets_minimum(options.minimum_tcb): + raise ValueError(f"Current TCB {current_tcb_parts} does not meet minimum requirements {options.minimum_tcb}") + if not committed_tcb_parts.meets_minimum(options.minimum_tcb): + raise ValueError(f"Committed TCB {committed_tcb_parts} does not meet minimum requirements {options.minimum_tcb}") + if not reported_tcb_parts.meets_minimum(options.minimum_tcb): + raise ValueError(f"Reported TCB {reported_tcb_parts} does not meet minimum requirements {options.minimum_tcb}") + + # VCEK-specific TCB check + chain.validate_vcek_tcb(TCBParts.from_int(report.reported_tcb)) + + if options.minimum_launch_tcb is not None: + launch_tcb_parts = TCBParts.from_int(report.launch_tcb) + if not launch_tcb_parts.meets_minimum(options.minimum_launch_tcb): + raise ValueError(f"Launch TCB {launch_tcb_parts} does not meet minimum requirements {options.minimum_launch_tcb}") + + # Field equality checks + if options.report_data is not None: + if len(report.report_data) != 64: + raise ValueError(f"Report data length is {len(report.report_data)}, expected 64 bytes") + if report.report_data != options.report_data: + raise ValueError(f"Report data mismatch: got {report.report_data.hex()}, expected {options.report_data.hex()}") + + if options.host_data is not None: + if len(report.host_data) != 32: + raise ValueError(f"Host data length is {len(report.host_data)}, expected 32 bytes") + if report.host_data != options.host_data: + raise ValueError(f"Host data mismatch: got {report.host_data.hex()}, expected {options.host_data.hex()}") + + if options.image_id is not None: + if len(report.image_id) != 16: + raise ValueError(f"Image ID length is {len(report.image_id)}, expected 16 bytes") + if report.image_id != options.image_id: + raise ValueError(f"Image ID mismatch: got {report.image_id.hex()}, expected {options.image_id.hex()}") + + if options.family_id is not None: + if len(report.family_id) != 16: + raise ValueError(f"Family ID length is {len(report.family_id)}, expected 16 bytes") + if report.family_id != options.family_id: + raise ValueError(f"Family ID mismatch: got {report.family_id.hex()}, expected {options.family_id.hex()}") + + if options.report_id is not None: + if len(report.report_id) != 32: + raise ValueError(f"Report ID length is {len(report.report_id)}, expected 32 bytes") + if report.report_id != options.report_id: + raise ValueError(f"Report ID mismatch: got {report.report_id.hex()}, expected {options.report_id.hex()}") + + if options.report_id_ma is not None: + if len(report.report_id_ma) != 32: + raise ValueError(f"Report ID MA length is {len(report.report_id_ma)}, expected 32 bytes") + if report.report_id_ma != options.report_id_ma: + raise ValueError(f"Report ID MA mismatch: got {report.report_id_ma.hex()}, expected {options.report_id_ma.hex()}") + + if options.measurement is not None: + if len(report.measurement) != 48: + raise ValueError(f"Measurement length is {len(report.measurement)}, expected 48 bytes") + if report.measurement != options.measurement: + raise ValueError(f"Measurement mismatch: got {report.measurement.hex()}, expected {options.measurement.hex()}") + + if options.chip_id is not None: + if len(report.chip_id) != 64: + raise ValueError(f"Chip ID length is {len(report.chip_id)}, expected 64 bytes") + if report.chip_id != options.chip_id: + raise ValueError(f"Chip ID mismatch: got {report.chip_id.hex()}, expected {options.chip_id.hex()}") + + # VCEK-specific CHIP_ID ↔ HWID equality check + if report.signer_info_parsed.signingKey == 
ReportSigner.VcekReportSigner: + if report.signer_info_parsed.maskChipKey and any(report.chip_id): + raise ValueError("maskChipKey is set but CHIP_ID is not zeroed") + if not report.signer_info_parsed.maskChipKey: + chain.validate_vcek_hwid(report.chip_id) + + # Platform info check + if options.platform_info is not None: + _validate_platform_info(report.platform_info_parsed, options.platform_info) + + # VMPL check + if options.vmpl is not None: # Must be between 0 and 3 and equal to the expected value + if not (0 <= report.vmpl <= 3): + raise ValueError(f"VMPL {report.vmpl} is not in valid range 0-3") + if report.vmpl != options.vmpl: + raise ValueError(f"VMPL mismatch: got {report.vmpl}, expected {options.vmpl}") + + # Provisional firmware check - we only support permit_provisional_firmware = False + if options.permit_provisional_firmware: + # Not supported - reject any request for provisional firmware + raise ValueError("Provisional firmware is not supported") + + # When permit_provisional_firmware = False, committed and current values must be equal + if report.committed_build != report.current_build: + raise ValueError(f"Committed build {report.committed_build} does not match current build {report.current_build}") + if report.committed_minor != report.current_minor: + raise ValueError(f"Committed minor version {report.committed_minor} does not match current minor version {report.current_minor}") + if report.committed_major != report.current_major: + raise ValueError(f"Committed major version {report.committed_major} does not match current major version {report.current_major}") + if report.committed_tcb != report.current_tcb: + raise ValueError(f"Committed TCB 0x{report.committed_tcb:x} does not match current TCB 0x{report.current_tcb:x}") + + # ID-block / author key requirements + if options.require_author_key or options.require_id_block: + # Not supported yet + raise ValueError("ID-block and author key requirements are not supported yet") + + +def _validate_policy(report_policy: SnpPolicy, required: SnpPolicy): + """ + Validate policy with security-aware checks. + + Logic follows Go reference implementation: + - Check ABI version compatibility + - Reject unauthorized capabilities (report has them, required doesn't allow) + - Reject missing required restrictions/features + + Raises ValueError if validation fails. + """ + # ABI version check - required version must not be greater than report version + if _compare_policy_versions(required, report_policy) > 0: + raise ValueError(f"Required ABI version ({required.abi_major}.{required.abi_minor}) is greater than report's ABI version ({report_policy.abi_major}.{report_policy.abi_minor})") + + # Unauthorized capabilities (report has them enabled, but required doesn't allow) + if not required.migrate_ma and report_policy.migrate_ma: + raise ValueError(f"Found unauthorized migration agent capability. Report policy: {report_policy}, Required policy: {required}") + + if not required.debug and report_policy.debug: + raise ValueError(f"Found unauthorized debug capability. Report policy: {report_policy}, Required policy: {required}") + + if not required.smt and report_policy.smt: + raise ValueError(f"Found unauthorized symmetric multithreading (SMT) capability. Report policy: {report_policy}, Required policy: {required}") + + if not required.cxl_allowed and report_policy.cxl_allowed: + raise ValueError(f"Found unauthorized CXL capability. 
Report policy: {report_policy}, Required policy: {required}")
+
+    if not required.mem_aes256_xts and report_policy.mem_aes256_xts:
+        raise ValueError(f"Found unauthorized memory encryption mode. Report policy: {report_policy}, Required policy: {required}")
+
+    # Required restrictions/features (report lacks what required mandates)
+    if required.single_socket and not report_policy.single_socket:
+        raise ValueError(f"Required single socket restriction not present. Report policy: {report_policy}, Required policy: {required}")
+
+    if required.mem_aes256_xts and not report_policy.mem_aes256_xts:
+        raise ValueError(f"Required AES-256-XTS memory encryption mode not enforced. Report policy: {report_policy}, Required policy: {required}")
+
+    if required.rapl_dis and not report_policy.rapl_dis:
+        raise ValueError(f"Required RAPL disable restriction not present. Report policy: {report_policy}, Required policy: {required}")
+
+    if required.ciphertext_hiding_dram and not report_policy.ciphertext_hiding_dram:
+        raise ValueError(f"Ciphertext hiding in DRAM isn't enforced. Report policy: {report_policy}, Required policy: {required}")
+
+    if required.page_swap_disabled and not report_policy.page_swap_disabled:
+        raise ValueError(f"Page swap isn't disabled. Report policy: {report_policy}, Required policy: {required}")
+
+def _compare_policy_versions(required: SnpPolicy, report: SnpPolicy) -> int:
+    """
+    Compare policy ABI versions.
+    Returns:
+        > 0 if required version is greater than report version
+        = 0 if versions are equal
+        < 0 if required version is less than report version
+    """
+    # Compare major version first
+    if required.abi_major != report.abi_major:
+        return required.abi_major - report.abi_major
+
+    # If major versions are equal, compare minor versions
+    return required.abi_minor - report.abi_minor
+
+
+def _validate_platform_info(report_info: SnpPlatformInfo, required: SnpPlatformInfo):
+    """
+    Validate platform info with security-aware checks.
+
+    Logic follows Go reference implementation:
+    - If report has a feature enabled that required doesn't allow -> FAIL
+    - If report lacks a feature that required mandates -> FAIL
+
+    Raises ValueError if validation fails.
+    """
+    # Unauthorized features (report has it enabled, but required doesn't allow it)
+    if report_info.smt_enabled and not required.smt_enabled:
+        raise ValueError(f"Unauthorized platform feature SMT enabled. Report platform info: {report_info}, Required platform info: {required}")
+
+    # Required features (report lacks something that required mandates)
+    if not report_info.ecc_enabled and required.ecc_enabled:
+        raise ValueError(f"Required platform feature ECC not enabled. Report platform info: {report_info}, Required platform info: {required}")
+
+    if not report_info.tsme_enabled and required.tsme_enabled:
+        raise ValueError(f"Required platform feature TSME not enabled. Report platform info: {report_info}, Required platform info: {required}")
+
+    if not report_info.rapl_disabled and required.rapl_disabled:
+        raise ValueError(f"Required platform feature RAPL not disabled. Report platform info: {report_info}, Required platform info: {required}")
+
+    if not report_info.ciphertext_hiding_dram_enabled and required.ciphertext_hiding_dram_enabled:
+        raise ValueError(f"Required ciphertext hiding in DRAM not enforced. Report platform info: {report_info}, Required platform info: {required}")
+
+    if not report_info.alias_check_complete and required.alias_check_complete:
+        raise ValueError(f"Required memory alias check hasn't been completed.
Report platform info: {report_info}, Required platform info: {required}") + + if not report_info.tio_enabled and required.tio_enabled: + raise ValueError(f"Required TIO not enabled. Report platform info: {report_info}, Required platform info: {required}") \ No newline at end of file diff --git a/src/tinfoil/attestation/verify_sev.py b/src/tinfoil/attestation/verify_sev.py new file mode 100644 index 0000000..29319e8 --- /dev/null +++ b/src/tinfoil/attestation/verify_sev.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 +""" +Simplified AMD SEV-SNP Attestation Verifier (VCEK Chain Only) +""" + +import os +from dataclasses import dataclass +from typing import Dict, TypeAlias +import binascii +import requests +from OpenSSL import crypto +import platformdirs + +from .abi_sev import (Report, ReportSigner, TCBParts) +from .genoa_cert_chain import (ARK_CERT, ASK_CERT) + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, utils +from cryptography.x509.oid import ObjectIdentifier +from pyasn1.codec.der import decoder as der_decoder +from pyasn1.type.univ import Integer as Asn1Integer +from pyasn1.type.char import IA5String as Asn1IA5String +import warnings +from cryptography.utils import CryptographyDeprecationWarning + +# Type alias for certificate extensions +Extensions: TypeAlias = Dict[ObjectIdentifier, bytes] + +# VCEK cache directory (created lazily on first use) +_VCEK_CACHE_DIR = platformdirs.user_cache_dir("tinfoil", "tinfoil") + + +def _ensure_vcek_cache_dir() -> None: + """Create the VCEK cache directory if it doesn't exist. Silently ignores failures.""" + try: + os.makedirs(_VCEK_CACHE_DIR, exist_ok=True) + except OSError: + pass + +class SnpOid: + """OID extensions for the VCEK, used to verify attestation report""" + BOOTLOADER = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.1") + TEE = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.2") + SNP = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.3") + STRUCT_VERSION = ObjectIdentifier("1.3.6.1.4.1.3704.1.1") + PRODUCT_NAME_1 = ObjectIdentifier("1.3.6.1.4.1.3704.1.2") + BL_SPL = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.1") + TEE_SPL = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.2") + SNP_SPL = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.3") + SPL4 = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.4") + SPL5 = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.5") + SPL6 = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.6") + SPL7 = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.7") + UCODE = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.8") + HWID = ObjectIdentifier("1.3.6.1.4.1.3704.1.4") + CSP_ID = ObjectIdentifier("1.3.6.1.4.1.3704.1.5") + +class CertificateChain: + """Represents the SEV certificate chain (ARK > ASK > VCEK)""" + ark: x509.Certificate + ask: x509.Certificate + vcek: x509.Certificate + + def __init__(self, ark: x509.Certificate, ask: x509.Certificate, vcek: x509.Certificate): + self.ark = ark + self.ask = ask + self.vcek = vcek + + @staticmethod + def _vcek_cache_path(product_name: str, chip_id: bytes, reported_tcb: int) -> str: + """ + Build a deterministic filename for a given (product, chip_id, tcb). + Uses the module-level _VCEK_CACHE_DIR. 
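+
+        A hypothetical result looks like::
+
+            <cache_dir>/VCEK_Genoa_<chip_id_hex>_<16-hex-char TCB>.der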
+ """ + chip_hex = chip_id.hex() + tcb_hex = f"{reported_tcb:016x}" + filename = f"VCEK_{product_name}_{chip_hex}_{tcb_hex}.der" + return os.path.join(_VCEK_CACHE_DIR, filename) + + @classmethod + def from_files(cls, ark_path: str, ask_path: str, vcek_path: str) -> 'CertificateChain': + """Alternative constructor to load certificates from files""" + ark = cls._load_cert(ark_path) + ask = cls._load_cert(ask_path) + vcek = cls._load_cert(vcek_path) + return cls(ark=ark, ask=ask, vcek=vcek) + + @classmethod + def from_report(cls, report:Report) -> 'CertificateChain': + productName: str = report.productName + + if productName != "Genoa": + raise ValueError("This implementation only supports Genoa processors") + + # Use the hardcoded certificate chain + ark = x509.load_pem_x509_certificate(ARK_CERT) + ask = x509.load_pem_x509_certificate(ASK_CERT) + + signer_info = report.signer_info_parsed + + if signer_info.signingKey != ReportSigner.VcekReportSigner: + raise ValueError("This implementation only supports VCEK signed reports") + + # Fetch (or load) the VCEK certificate + vcek_url = _VCEKCertURL(productName, report.chip_id, report.reported_tcb) + cache_path = cls._vcek_cache_path(productName, report.chip_id, report.reported_tcb) + + # 1. Try the on‑disk cache + if os.path.isfile(cache_path): + with open(cache_path, "rb") as fh: + vcek_cert_data = fh.read() + else: + # 2. Cache miss → fetch from the KDS endpoint + try: + response = requests.get(vcek_url, timeout=10) + response.raise_for_status() + vcek_cert_data = response.content + # Persist to cache so the next call is instant + _ensure_vcek_cache_dir() + try: + tmp_path = cache_path + ".tmp" + with open(tmp_path, "wb") as fh: + fh.write(vcek_cert_data) + os.replace(tmp_path, cache_path) + except OSError: + pass + except requests.RequestException as e: + raise ValueError(f"Failed to fetch VCEK certificate: {e}") from e + + # Parse the (cached or freshly‑downloaded) certificate + try: + # cryptography 46+ emits a deprecation warning for non‑positive serial numbers. + # Suppress this specific deprecation warning locally when parsing VCEK DER. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=r"Parsed a serial number which wasn't positive", + category=CryptographyDeprecationWarning, + ) + vcek = x509.load_der_x509_certificate(vcek_cert_data) + except Exception as e: + # Corrupted cache? Remove and propagate error so caller can retry. + if os.path.exists(cache_path): + try: + os.remove(cache_path) + except OSError: + pass + raise ValueError(f"Failed to parse VCEK certificate: {e}") from e + + # Return the complete certificate chain + return cls(ark=ark, ask=ask, vcek=vcek) + + @staticmethod + def _load_cert(filepath: str) -> x509.Certificate: + """Load an X.509 certificate from file""" + _, ext = os.path.splitext(filepath) + with open(filepath, 'rb') as f: + data = f.read() + + if ext.lower() == '.pem': + return x509.load_pem_x509_certificate(data) + else: + # Suppress cryptography deprecation warnings for DER parsing as above. 
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore",
+                    message=r"Parsed a serial number which wasn't positive",
+                    category=CryptographyDeprecationWarning,
+                )
+                return x509.load_der_x509_certificate(data)
+
+    def verify_chain(self) -> bool:
+        # Validate VCEK format
+        try:
+            self._validate_vcek_format()
+        except ValueError as e:
+            print(f"VCEK certificate validation failed: {e}")
+            return False
+
+        # Validate ARK and ASK format
+        try:
+            self._validate_ark_format()
+        except ValueError as e:
+            print(f"ARK certificate validation failed: {e}")
+            return False
+        try:
+            self._validate_ask_format()
+        except ValueError as e:
+            print(f"ASK certificate validation failed: {e}")
+            return False
+
+        # Verify the certificate chain using OpenSSL
+        try:
+            # Create a store and add the root (ARK) certificate
+            store = crypto.X509Store()
+            store.add_cert(crypto.X509.from_cryptography(self.ark))
+
+            # Add the intermediate (ASK) certificate
+            store.add_cert(crypto.X509.from_cryptography(self.ask))
+
+            # Create a store context
+            store_ctx = crypto.X509StoreContext(store, crypto.X509.from_cryptography(self.vcek))
+
+            # Verify the certificate
+            try:
+                store_ctx.verify_certificate()
+                return True
+            except crypto.X509StoreContextError as e:
+                print(f"Certificate chain verification failed: {e}")
+                return False
+
+        except Exception as e:
+            print(f"Error during chain verification: {e}")
+            return False
+
+    def _validate_ark_format(self):
+        if self.ark.version != x509.Version.v3:
+            raise ValueError("ARK certificate version is not 3")
+        if not _validateAmdLocation(self.ark.issuer):
+            raise ValueError("ARK certificate issuer is not a valid AMD location")
+        if not _validateAmdLocation(self.ark.subject):
+            raise ValueError("ARK certificate subject is not a valid AMD location")
+
+        cn = self.ark.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value
+        if cn != "ARK-Genoa":
+            raise ValueError(f"ARK certificate subject common name is not ARK-Genoa but {cn}")
+
+        # TODO add support for Certificate Revocation Lists
+        # NOTE: the Go implementation cross-checks the SEV certificate format against the X.509 certificate, but we only trust the certificates we ship with the code
+
+    def _validate_ask_format(self):
+        # Validate ASK format
+        if self.ask.version != x509.Version.v3:
+            raise ValueError("ASK certificate version is not 3")
+        if not _validateAmdLocation(self.ask.issuer):
+            raise ValueError("ASK certificate issuer is not a valid AMD location")
+        if not _validateAmdLocation(self.ask.subject):
+            raise ValueError("ASK certificate subject is not a valid AMD location")
+
+        cn = self.ask.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value
+        if cn != "SEV-Genoa":
+            raise ValueError(f"ASK certificate subject common name is not SEV-Genoa but {cn}")
+
+        # TODO add support for Certificate Revocation Lists
+        # NOTE: the Go implementation cross-checks the SEV certificate format against the X.509 certificate, but we only trust the certificates we ship with the code
+
+    def _validate_vcek_format(self):
+        """Validate the format of a VCEK certificate"""
+
+        if self.vcek.version != x509.Version.v3:
+            raise ValueError(f"VCEK certificate version is not 3 but {self.vcek.version}")
+
+        if self.vcek.signature_algorithm_oid != x509.SignatureAlgorithmOID.RSASSA_PSS:
+            raise ValueError(f"VCEK certificate signature algorithm is not RSASSA_PSS but {self.vcek.signature_algorithm_oid}")
+
+        if self.vcek.public_key_algorithm_oid != x509.PublicKeyAlgorithmOID.EC_PUBLIC_KEY:
+            raise ValueError(f"VCEK certificate public key algorithm is not ECDSA but
{self.vcek.public_key_algorithm_oid}") + + if self.vcek.public_key().curve.name != "secp384r1": + raise ValueError(f"VCEK certificate public key curve is not secp384r1 but {self.vcek.public_key().curve.name}") + + # Validate KDS Cert Subject + if not _validateAmdLocation(self.vcek.subject): + raise ValueError("VCEK certificate subject is not a valid AMD location") + + cn = self.vcek.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value + if cn != "SEV-VCEK": + raise ValueError(f"VCEK certificate subject common name is not SEV-VCEK but {cn}") + + # Get KDS and validate Cert Extensions + extensions = _get_certificate_extensions(self.vcek) + if SnpOid.CSP_ID in extensions: + raise ValueError(f"unexpected CSP_ID in VCEK certificate: {extensions[SnpOid.CSP_ID]}") + + if (SnpOid.HWID not in extensions) or len(extensions[SnpOid.HWID]) != 64: # ChipIDSize + raise ValueError(f"missing HWID extension for VCEK certificate") + + product_name = _decode_der_ia5string(extensions[SnpOid.PRODUCT_NAME_1]) + if product_name != "Genoa": + raise ValueError(f"unexpected PRODUCT_NAME_1 in VCEK certificate: {product_name}") + + def validate_vcek_tcb(self, tcb: TCBParts): + """Validate the TCB extension in the VCEK certificate matches a given TCB""" + extensions = _get_certificate_extensions(self.vcek) + + if SnpOid.BL_SPL not in extensions: + raise ValueError(f"missing BL_SPL extension for VCEK certificate") + bl_spl = _decode_der_integer(extensions[SnpOid.BL_SPL]) + if bl_spl != tcb.bl_spl: + raise ValueError(f"BL_SPL extension in VCEK certificate does not match tcb.bl_spl: {bl_spl} != {tcb.bl_spl}") + + if SnpOid.TEE_SPL not in extensions: + raise ValueError(f"missing TEE_SPL extension for VCEK certificate") + tee_spl = _decode_der_integer(extensions[SnpOid.TEE_SPL]) + if tee_spl != tcb.tee_spl: + raise ValueError(f"TEE_SPL extension in VCEK certificate does not match tcb.tee_spl: {tee_spl} != {tcb.tee_spl}") + + if SnpOid.SNP_SPL not in extensions: + raise ValueError(f"missing SNP_SPL extension for VCEK certificate") + snp_spl = _decode_der_integer(extensions[SnpOid.SNP_SPL]) + if snp_spl != tcb.snp_spl: + raise ValueError(f"SNP_SPL extension in VCEK certificate does not match tcb.snp_spl: {snp_spl} != {tcb.snp_spl}") + + if SnpOid.UCODE not in extensions: + raise ValueError(f"missing UCODE extension for VCEK certificate") + ucode_spl = _decode_der_integer(extensions[SnpOid.UCODE]) + if ucode_spl != tcb.ucode_spl: + raise ValueError(f"UCODE extension in VCEK certificate does not match tcb.ucode_spl: {ucode_spl} != {tcb.ucode_spl}") + + def validate_vcek_hwid(self, chip_id: bytes): + """Validate the HWID extension in the VCEK certificate matches a given chip id""" + extensions = _get_certificate_extensions(self.vcek) + if SnpOid.HWID not in extensions: + raise ValueError(f"missing HWID extension for VCEK certificate") + if extensions[SnpOid.HWID] != chip_id: + raise ValueError(f"HWID extension in VCEK certificate does not match chip_id: {extensions[SnpOid.HWID]} != {chip_id}") + + +## HELPER FUNCTIONS + +def _get_certificate_extensions(cert: x509.Certificate) -> Extensions: + """Get the extensions from the VCEK certificate""" + extensions = {} + for ext in cert.extensions: + extensions[ext.oid] = ext.value.value + return extensions + +def _decode_der_integer(der_bytes: bytes) -> int: + """Decode a DER-encoded INTEGER, rejecting non-INTEGER types and trailing bytes.""" + try: + result, remainder = der_decoder.decode(der_bytes, asn1Spec=Asn1Integer()) + except Exception as e: + raise 
ValueError(f"Failed to decode DER integer: {e}") from e + if remainder: + raise ValueError( + f"Unexpected trailing bytes after DER integer: {len(remainder)} bytes" + ) + return int(result) + +def _decode_der_ia5string(der_bytes: bytes) -> str: + """Decode a DER-encoded IA5String, rejecting non-IA5String types and trailing bytes.""" + try: + result, remainder = der_decoder.decode(der_bytes, asn1Spec=Asn1IA5String()) + except Exception as e: + raise ValueError(f"Failed to decode DER IA5String: {e}") from e + if remainder: + raise ValueError( + f"Unexpected trailing bytes after DER IA5String: {len(remainder)} bytes" + ) + return str(result) + +def _validateAmdLocation(name: x509.Name) -> bool: + """Validate that the certificate subject name matches AMD's expected values. + + Args: + name: The x509.Name object to validate + + Returns: + bool: True if all fields match expected values, False otherwise + """ + def check_singleton_list(values: list[str], field_name: str, expected: str) -> bool: + if len(values) != 1: + print(f"Expected exactly one {field_name}, got {len(values)}") + return False + if values[0] != expected: + print(f"Unexpected {field_name} value: '{values[0]}', expected '{expected}'") + return False + return True + + # Get the name attributes + country = name.get_attributes_for_oid(x509.NameOID.COUNTRY_NAME) + locality = name.get_attributes_for_oid(x509.NameOID.LOCALITY_NAME) + state = name.get_attributes_for_oid(x509.NameOID.STATE_OR_PROVINCE_NAME) + org = name.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME) + org_unit = name.get_attributes_for_oid(x509.NameOID.ORGANIZATIONAL_UNIT_NAME) + + # Extract the values from the attributes + country_values = [attr.value for attr in country] + locality_values = [attr.value for attr in locality] + state_values = [attr.value for attr in state] + org_values = [attr.value for attr in org] + org_unit_values = [attr.value for attr in org_unit] + + # Validate each field + if not check_singleton_list(country_values, "country", "US"): + return False + if not check_singleton_list(locality_values, "locality", "Santa Clara"): + return False + if not check_singleton_list(state_values, "state", "CA"): + return False + if not check_singleton_list(org_values, "organization", "Advanced Micro Devices"): + return False + if not check_singleton_list(org_unit_values, "organizational unit", "Engineering"): + return False + + return True + +def _VCEKCertURL(productName: str, chip_id: bytes, reported_tcb: int) -> str: + # TODO add support for other product names + """Generate the VCEK certificate URL based on the product name, chip ID, and reported TCB""" + import urllib.parse + parts = TCBParts.from_int(reported_tcb) + base_url = "https://kds-proxy.tinfoil.sh/vcek/v1" + chip_id_hex = binascii.hexlify(chip_id).decode('ascii') + path = f"{base_url}/{urllib.parse.quote(productName, safe='')}/{chip_id_hex}" + query = urllib.parse.urlencode({ + "blSPL": parts.bl_spl, + "teeSPL": parts.tee_spl, + "snpSPL": parts.snp_spl, + "ucodeSPL": parts.ucode_spl, + }) + return f"{path}?{query}" + +def _verify_report_signature(vcek: x509.Certificate, report: Report) -> bool: + """Verify the attestation report signature using VCEK's public key""" + + # Validate Report Format + POLICY_RESERVED_1_BIT = 17 + + if report.version < 2: + raise ValueError(f"Report version is lower than 2: is {report.version}") + + # Check reserved bit must be 1 + if not (report.policy & (1 << POLICY_RESERVED_1_BIT)): + raise ValueError(f"policy[{POLICY_RESERVED_1_BIT}] is reserved, must be 1, got 
0") + + # Check bits 63-26 must be zero + if report.policy >> 26: + raise ValueError("policy bits 63-26 must be zero") + + try: + # Check signature algorithm + if report.signature_algo != 1: # 1 = SignEcdsaP384Sha384 + print(f"Unknown SignatureAlgo: {report.signature_algo}") + return False + + # Verify the public key is an EC key + public_key = vcek.public_key() + if not isinstance(public_key, ec.EllipticCurvePublicKey): + print("VCEK doesn't contain an EC public key") + return False + + # Convert the raw signature to DER format + # The signature in the report is in raw R||S format in AMD's little-endian format + # Each component is 72 bytes (0x48) for P384 + r_bytes = bytes(reversed(report.signature[0:0x48])) # Reverse bytes for big-endian + s_bytes = bytes(reversed(report.signature[0x48:0x90])) # Reverse bytes for big-endian + + r = int.from_bytes(r_bytes.lstrip(b'\x00'), byteorder='big') + s = int.from_bytes(s_bytes.lstrip(b'\x00'), byteorder='big') + + der_signature = utils.encode_dss_signature(r, s) + + # Verify signature + public_key.verify( + der_signature, + report.signed_data, + ec.ECDSA(hashes.SHA384()) + ) + return True + except Exception as e: + print(f"Attestation signature verification failed: {e}") + return False + +def verify_attestation(chain: CertificateChain, report: Report) -> bool: + """Verify attestation report with the certificate chain""" + try: + # Verify certificate chain + if not chain.verify_chain(): + # Since verify_chain() already prints its own error messages + return False + + # Verify report + if not _verify_report_signature(chain.vcek, report): + return False + + return True + + except Exception as e: + print(f"Verification failed: {e}") + return False + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Simplified AMD SEV-SNP Attestation Verifier") + parser.add_argument("--ark", required=True, help="Path to ARK certificate") + parser.add_argument("--ask", required=True, help="Path to ASK certificate") + parser.add_argument("--vcek", required=True, help="Path to VCEK certificate") + parser.add_argument("--report", required=True, help="Path to attestation report") + + args = parser.parse_args() + + # Load certificate chain + chain = CertificateChain.from_files(args.ark, args.ask, args.vcek) + + # Read and parse attestation report + with open(args.report, 'rb') as f: + report_data = f.read() + + report = Report(report_data) + + result = verify_attestation(chain, report) + if result: + print("Attestation verification successful") + return 0 + else: + print("Attestation verification failed") + return 1 + + +if __name__ == "__main__": + import sys + sys.exit(main()) \ No newline at end of file diff --git a/tests/test_sev_validation.py b/tests/test_sev_validation.py index 9650f56..6ae1059 100644 --- a/tests/test_sev_validation.py +++ b/tests/test_sev_validation.py @@ -1,12 +1,12 @@ import pytest -from tinfoil.attestation.validate import ( +from tinfoil.attestation.validate_sev import ( ValidationOptions, validate_report, _validate_policy, _compare_policy_versions, _validate_platform_info ) -from tinfoil.attestation.abi_sevsnp import ( +from tinfoil.attestation.abi_sev import ( Report, SnpPolicy, SnpPlatformInfo, @@ -14,7 +14,7 @@ SignerInfo, ReportSigner ) -from tinfoil.attestation.verify import CertificateChain +from tinfoil.attestation.verify_sev import CertificateChain class TestValidationOptions: From f4e06682fc37e6dde7b7dfe8d6a1f2e8601b0fc3 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:15:11 -0500 
Subject: [PATCH 09/19] feat: add shared attestation types and helpers (types.py) Extract Measurement, Verification, GroundTruth, and error classes into shared types module. Add safe_gzip_decompress with bomb protection and proper ValueError wrapping, fingerprint() with consistent SHA-256 return type, and named constants for register indices and key sizes. --- src/tinfoil/attestation/types.py | 199 +++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 src/tinfoil/attestation/types.py diff --git a/src/tinfoil/attestation/types.py b/src/tinfoil/attestation/types.py new file mode 100644 index 0000000..f8a4186 --- /dev/null +++ b/src/tinfoil/attestation/types.py @@ -0,0 +1,199 @@ +""" +Shared types, errors, and protocol constants for attestation. + +This module is the canonical source for types used across TDX and SEV +attestation modules. It has no intra-package dependencies, so any module +can import from it without risk of circular imports. +""" + +import hashlib +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + + +# ============================================================================= +# Protocol-level constants (shared across TDX and SEV) +# ============================================================================= + +TLS_KEY_FP_SIZE = 32 # SHA-256 TLS public key fingerprint (bytes) +HPKE_KEY_SIZE = 32 # HPKE public key (bytes) + +# RTMR3 should always be zeros (48 bytes = 96 hex chars) +RTMR3_ZERO = "0" * 96 + +# Register layout constants per platform +TDX_REGISTER_COUNT = 5 # [mrtd, rtmr0, rtmr1, rtmr2, rtmr3] +TDX_MRTD_IDX = 0 +TDX_RTMR0_IDX = 1 +TDX_RTMR1_IDX = 2 +TDX_RTMR2_IDX = 3 +TDX_RTMR3_IDX = 4 +SEV_REGISTER_COUNT = 1 # [snp_measurement] +MULTIPLATFORM_REGISTER_COUNT = 3 # [snp_measurement, rtmr1, rtmr2] +MULTIPLATFORM_SNP_IDX = 0 +MULTIPLATFORM_RTMR1_IDX = 1 +MULTIPLATFORM_RTMR2_IDX = 2 + +# Shared decompression constants +MAX_DECOMPRESSED_SIZE = 10 * 1024 * 1024 # 10 MiB + + +def safe_gzip_decompress(data: bytes, max_size: int = MAX_DECOMPRESSED_SIZE) -> bytes: + """Decompress gzip data with a size limit to prevent gzip bombs. 
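+
+    Reads at most ``max_size + 1`` bytes from the stream, so an oversized
+    payload is detected without buffering the whole decompressed output.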
+ + Args: + data: Gzip-compressed bytes + max_size: Maximum allowed decompressed size + + Returns: + Decompressed bytes + + Raises: + ValueError: If decompressed data exceeds max_size or decompression fails + """ + import gzip + import io + + try: + with gzip.GzipFile(fileobj=io.BytesIO(data)) as f: + result = f.read(max_size + 1) + except (OSError, EOFError) as e: + raise ValueError(f"Gzip decompression failed: {e}") from e + if len(result) > max_size: + raise ValueError( + f"Decompressed attestation exceeds maximum size ({max_size} bytes)" + ) + return result + + +# ============================================================================= +# Predicate types +# ============================================================================= + +class PredicateType(str, Enum): + """Predicate types for attestation""" + SEV_GUEST_V2 = "https://tinfoil.sh/predicate/sev-snp-guest/v2" + TDX_GUEST_V2 = "https://tinfoil.sh/predicate/tdx-guest/v2" + SNP_TDX_MULTIPLATFORM_v1 = "https://tinfoil.sh/predicate/snp-tdx-multiplatform/v1" + HARDWARE_MEASUREMENTS_V1 = "https://tinfoil.sh/predicate/hardware-measurements/v1" + + +TDX_TYPES = (PredicateType.TDX_GUEST_V2,) + + +# ============================================================================= +# Errors +# ============================================================================= + +class AttestationError(Exception): + """Base class for attestation errors""" + pass + +class FormatMismatchError(AttestationError): + """Raised when attestation formats don't match""" + pass + +class MeasurementMismatchError(AttestationError): + """Raised when measurements don't match""" + pass + +class Rtmr3NotZeroError(AttestationError): + """Raised when RTMR3 is not zeros""" + pass + +class HardwareMeasurementError(AttestationError): + """Raised when hardware measurement verification fails""" + pass + + +# ============================================================================= +# Data types +# ============================================================================= + +@dataclass +class HardwareMeasurement: + """Represents hardware platform measurements (MRTD and RTMR0 for TDX)""" + id: str # platform@digest + mrtd: str + rtmr0: str + +@dataclass +class Measurement: + """Represents measurement data""" + type: PredicateType + registers: List[str] + + def fingerprint(self) -> str: + """ + Computes the SHA-256 hash of the predicate type and all measurement + registers. Always returns a 64-char hex digest regardless of the + number of registers, so callers get a uniform format. 
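+
+        Example (hypothetical register value)::
+
+            m = Measurement(type=PredicateType.SEV_GUEST_V2, registers=["ab" * 48])
+            assert len(m.fingerprint()) == 64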
+ """ + if not self.registers: + raise ValueError("Cannot compute fingerprint: no measurement registers") + + all_data = self.type.value + "".join(self.registers) + return hashlib.sha256(all_data.encode()).hexdigest() + + def assert_equal(self, other: 'Measurement') -> None: + """ + Checks if this measurement equals another measurement with multi-platform support + Raises appropriate error if they don't match + """ + # Direct comparison for same types + if self.type == other.type: + if len(self.registers) != len(other.registers) or self.registers != other.registers: + raise MeasurementMismatchError() + return + + # Multi-platform comparison support + if self.type == PredicateType.SNP_TDX_MULTIPLATFORM_v1: + if other.type == PredicateType.TDX_GUEST_V2 and len(other.registers) == TDX_REGISTER_COUNT: + if (len(self.registers) != MULTIPLATFORM_REGISTER_COUNT or + self.registers[MULTIPLATFORM_RTMR1_IDX] != other.registers[TDX_RTMR1_IDX] or + self.registers[MULTIPLATFORM_RTMR2_IDX] != other.registers[TDX_RTMR2_IDX]): + raise MeasurementMismatchError() + if other.registers[TDX_RTMR3_IDX] != RTMR3_ZERO: + raise Rtmr3NotZeroError(f"RTMR3 must be zeros, got {other.registers[TDX_RTMR3_IDX]}") + return + elif other.type == PredicateType.SEV_GUEST_V2 and len(other.registers) == SEV_REGISTER_COUNT: + if (len(self.registers) != MULTIPLATFORM_REGISTER_COUNT or + self.registers[MULTIPLATFORM_SNP_IDX] != other.registers[0]): + raise MeasurementMismatchError() + return + + # Reverse comparisons + if other.type == PredicateType.SNP_TDX_MULTIPLATFORM_v1: + try: + other.assert_equal(self) + return + except (FormatMismatchError, MeasurementMismatchError): + raise + + # If we get here, the formats are incompatible + raise FormatMismatchError() + + def __str__(self) -> str: + """Returns a human-readable string representation of the measurement""" + if self.type == PredicateType.SEV_GUEST_V2 and len(self.registers) == SEV_REGISTER_COUNT: + return f"Measurement(type={self.type.value}, snp_measurement={self.registers[0][:16]}...)" + + elif self.type == PredicateType.TDX_GUEST_V2 and len(self.registers) == TDX_REGISTER_COUNT: + labels = ["mrtd", "rtmr0", "rtmr1", "rtmr2", "rtmr3"] + parts = [f"{label}={reg[:16]}..." for label, reg in zip(labels, self.registers)] + return f"Measurement(type={self.type.value}, {', '.join(parts)})" + + elif self.type == PredicateType.SNP_TDX_MULTIPLATFORM_v1 and len(self.registers) == MULTIPLATFORM_REGISTER_COUNT: + labels = ["snp_measurement", "rtmr1", "rtmr2"] + parts = [f"{label}={reg[:16]}..." for label, reg in zip(labels, self.registers)] + return f"Measurement(type={self.type.value}, {', '.join(parts)})" + + return f"Measurement(type={self.type.value}, registers={len(self.registers)} items)" + +@dataclass +class Verification: + """Represents verification results""" + measurement: Measurement + public_key_fp: str + hpke_public_key: Optional[str] = None From bb1d63de937ec64099e2678b95e79418b7295690 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:15:16 -0500 Subject: [PATCH 10/19] feat: add Sigstore DSSE verification for hardware measurements Extract shared DSSE verification helper, add multiplatform TDX/SEV hardware measurement fetching from GitHub via Sigstore transparency log. Supports TDX_GUEST_V2 and SEV_GUEST_V2 predicate types. 
--- src/tinfoil/sigstore.py | 179 +++++++++++++++++++++++++++++++--------- 1 file changed, 139 insertions(+), 40 deletions(-) diff --git a/src/tinfoil/sigstore.py b/src/tinfoil/sigstore.py index dfd7446..495c0b2 100644 --- a/src/tinfoil/sigstore.py +++ b/src/tinfoil/sigstore.py @@ -5,7 +5,9 @@ import json import re -from .attestation import Measurement, PredicateType +from typing import List +from .attestation import Measurement, PredicateType, HardwareMeasurement +from .github import fetch_latest_digest, fetch_attestation_bundle OIDC_ISSUER = "https://token.actions.githubusercontent.com" @@ -31,6 +33,54 @@ def verify(self, cert: Certificate) -> None: f"({_OIDC_GITHUB_WORKFLOW_REF_OID.dotted_string}) extension" ) + +def _verify_dsse_bundle(bundle_json: bytes, digest: str, repo: str) -> dict: + """ + Verify a Sigstore DSSE bundle and return the parsed in-toto payload. + + Performs signature verification, certificate identity policy checks, + Rekor log consistency, payload type validation, and digest matching. + + Args: + bundle_json: Raw Sigstore bundle JSON + digest: Expected SHA256 hex digest of the DSSE payload subject + repo: GitHub repository (e.g. "tinfoilsh/confidential-router") + + Returns: + Parsed in-toto statement dict with predicateType, predicate, subject, etc. + + Raises: + ValueError: If any verification step fails + """ + verifier = Verifier.production() + bundle = Bundle.from_json(bundle_json) + + policy = AllOf([ + OIDCIssuer(OIDC_ISSUER), + GitHubWorkflowRepository(repo), + GitHubWorkflowRefPattern("refs/tags/.*") + ]) + + payload_type, payload_bytes = verifier.verify_dsse(bundle, policy) + + if payload_type != 'application/vnd.in-toto+json': + raise ValueError(f"Unsupported payload type: {payload_type}") + + statement = json.loads(payload_bytes) + + subjects = statement.get("subject", []) + if not subjects or "digest" not in subjects[0] or "sha256" not in subjects[0].get("digest", {}): + raise ValueError("Invalid in-toto statement: missing or empty subject") + + if digest != subjects[0]["digest"]["sha256"]: + raise ValueError( + f"Digest mismatch: expected {digest}, " + f"got {subjects[0]['digest']['sha256']}" + ) + + return statement + + def verify_attestation(bundle_json: bytes, digest: str, repo: str) -> Measurement: """ Verifies the attested measurements of an enclave image against a trusted root (Sigstore) @@ -48,45 +98,26 @@ def verify_attestation(bundle_json: bytes, digest: str, repo: str) -> Measuremen ValueError: If verification fails or digests don't match """ try: - # Create verifier with the trusted root - verifier = Verifier.production() - - # Parse the bundle - bundle = Bundle.from_json(bundle_json) - - # Create verification policy for GitHub Actions certificate identity - policy = AllOf([ - OIDCIssuer(OIDC_ISSUER), - GitHubWorkflowRepository(repo), - GitHubWorkflowRefPattern("refs/tags/.*") - ]) - - # --- Core DSSE Verification --- - # This verifies the signature on the DSSE envelope, applies the - # certificate identity policy, and checks Rekor log consistency. - # It returns the verified payload from within the envelope. - payload_type, payload_bytes = verifier.verify_dsse(bundle, policy) - - # --- Process the Verified Payload --- - if payload_type != 'application/vnd.in-toto+json': - raise ValueError(f"Unsupported payload type: {payload_type}. 
Only supports In-toto.") - - result_json = json.loads(payload_bytes) - predicate_type = PredicateType(result_json["predicateType"]) - predicate_fields = result_json["predicate"] - - # --- Manual Payload Digest Verification --- - # Now, verify that the provided external digest matches the - # actual digest in the payload returned from the verified envelope. - if digest != result_json["subject"][0]["digest"]["sha256"]: - raise ValueError( - f"Provided digest does not match verified DSSE payload digest. " - f"Expected: {digest}, Got: {result_json['subject'][0]['digest']['sha256']}" - ) - - # Convert predicate type to measurement type + statement = _verify_dsse_bundle(bundle_json, digest, repo) + + predicate_type = PredicateType(statement["predicateType"]) + predicate_fields = statement["predicate"] + if predicate_type == PredicateType.SNP_TDX_MULTIPLATFORM_v1: - registers = [predicate_fields["snp_measurement"]] + snp_measurement = predicate_fields.get("snp_measurement") + if not snp_measurement: + raise ValueError("Invalid multiplatform measurement: no snp_measurement") + + tdx_measurement = predicate_fields.get("tdx_measurement") + if not tdx_measurement: + raise ValueError("Invalid multiplatform measurement: no tdx_measurement") + + rtmr1 = tdx_measurement.get("rtmr1") + rtmr2 = tdx_measurement.get("rtmr2") + if not rtmr1 or not rtmr2: + raise ValueError("Invalid multiplatform measurement: missing rtmr1 or rtmr2") + + registers = [snp_measurement, rtmr1, rtmr2] else: raise ValueError(f"Unsupported predicate type: {predicate_type}") @@ -94,6 +125,74 @@ def verify_attestation(bundle_json: bytes, digest: str, repo: str) -> Measuremen type=predicate_type, registers=registers ) - + except Exception as e: raise ValueError(f"Attestation processing failed: {e}") from e + + +HARDWARE_MEASUREMENTS_REPO = "tinfoilsh/hardware-measurements" + + +def verify_hardware_measurements(bundle_json: bytes, digest: str, repo: str) -> List[HardwareMeasurement]: + """ + Verifies hardware measurements from a Sigstore bundle. + + Args: + bundle_json: The bundle JSON data (bytes) + digest: The expected hex-encoded SHA256 digest + repo: The repository name + + Returns: + List of HardwareMeasurement objects + + Raises: + ValueError: If verification fails or predicate type is unexpected + """ + try: + statement = _verify_dsse_bundle(bundle_json, digest, repo) + + predicate_type = statement["predicateType"] + if predicate_type != PredicateType.HARDWARE_MEASUREMENTS_V1.value: + raise ValueError(f"Unexpected predicate type: {predicate_type}") + + predicate_fields = statement["predicate"] + measurements = [] + + for platform_id, platform_data in predicate_fields.items(): + if not isinstance(platform_data, dict): + raise ValueError(f"Invalid hardware measurement for {platform_id}") + + mrtd = platform_data.get("mrtd") + rtmr0 = platform_data.get("rtmr0") + + if not mrtd or not rtmr0: + raise ValueError(f"Invalid hardware measurement for {platform_id}: missing mrtd or rtmr0") + + measurements.append(HardwareMeasurement( + id=f"{platform_id}@{digest}", + mrtd=mrtd, + rtmr0=rtmr0, + )) + + return measurements + + except Exception as e: + raise ValueError(f"Hardware measurements processing failed: {e}") from e + + +def fetch_latest_hardware_measurements() -> List[HardwareMeasurement]: + """ + Fetches the latest hardware measurements from GitHub + Sigstore. 
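+
+    Convenience wrapper that chains fetch_latest_digest,
+    fetch_attestation_bundle, and verify_hardware_measurements against
+    HARDWARE_MEASUREMENTS_REPO.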
+ + Returns: + List of HardwareMeasurement objects + + Raises: + ValueError: If fetching or verification fails + """ + try: + digest = fetch_latest_digest(HARDWARE_MEASUREMENTS_REPO) + bundle_json = fetch_attestation_bundle(HARDWARE_MEASUREMENTS_REPO, digest) + return verify_hardware_measurements(bundle_json, digest, HARDWARE_MEASUREMENTS_REPO) + except Exception as e: + raise ValueError(f"Hardware measurements fetching failed: {e}") from e From 98b953e382c3c5592998918550b8921bcd98d738 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:15:22 -0500 Subject: [PATCH 11/19] feat: add secure httpx client with TLS pinning Add make_secure_http_client and make_secure_async_http_client with certificate fingerprint pinning. Integrate TdxVerificationConfig, platform-filtered router selection with URL-encoded query params, and TDX hardware measurement verification in the verify() flow. --- src/tinfoil/client.py | 61 +++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/src/tinfoil/client.py b/src/tinfoil/client.py index 124baa1..8462b82 100644 --- a/src/tinfoil/client.py +++ b/src/tinfoil/client.py @@ -1,19 +1,21 @@ import http.client import json import ssl +import urllib.error import urllib.request import httpx import random from dataclasses import dataclass from typing import Dict, Optional -from urllib.parse import urlparse +from urllib.parse import urlparse, urlencode import cryptography.x509 from cryptography.hazmat.primitives.serialization import PublicFormat, Encoding import hashlib -from .attestation import fetch_attestation +from .attestation import fetch_attestation, TDX_TYPES +from .attestation.attestation_tdx import verify_tdx_hardware from .github import fetch_latest_digest, fetch_attestation_bundle -from .sigstore import verify_attestation +from .sigstore import verify_attestation, fetch_latest_hardware_measurements @dataclass @@ -114,14 +116,14 @@ def wrap_socket(*args, **kwargs) -> ssl.SSLSocket: sock = ssl.create_default_context().wrap_socket(*args, **kwargs) cert_binary = sock.getpeercert(binary_form=True) if not cert_binary: - raise Exception("No certificate found") + raise ValueError("No certificate found") cert = cryptography.x509.load_der_x509_certificate(cert_binary) pub_der = cert.public_key().public_bytes( Encoding.DER, PublicFormat.SubjectPublicKeyInfo ) pk_fp = hashlib.sha256(pub_der).hexdigest() if pk_fp != expected_fp: - raise Exception(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}") + raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}") return sock return wrap_socket @@ -152,8 +154,9 @@ def verify(self) -> GroundTruth: if expected_snp_measurement is None: raise ValueError("snp_measurement not found in provided measurement") - # Get the actual measurement from the attestation - actual_measurement = verification.measurement.registers[0] # SNP measurement is the first register + if not verification.measurement.registers: + raise ValueError("No measurement registers found in attestation") + actual_measurement = verification.measurement.registers[0] if actual_measurement != expected_snp_measurement: raise ValueError(f"SNP measurement mismatch: expected {expected_snp_measurement}, got {actual_measurement}") @@ -169,17 +172,20 @@ def verify(self) -> GroundTruth: # GitHub-based verification digest = fetch_latest_digest(self.repo) sigstore_bundle = fetch_attestation_bundle(self.repo, digest) - + code_measurements = verify_attestation( - 
sigstore_bundle, - digest, + sigstore_bundle, + digest, self.repo ) - - # Verify measurements match - for (i, code_measurement) in enumerate(code_measurements.registers): - if code_measurement != verification.measurement.registers[i]: - raise ValueError("Code measurements do not match") + + # For TDX, verify hardware measurements (MRTD and RTMR0) + if verification.measurement.type in TDX_TYPES: + hardware_measurements = fetch_latest_hardware_measurements() + verify_tdx_hardware(hardware_measurements, verification.measurement) + + # Verify measurements match (handles cross-platform comparison) + code_measurements.assert_equal(verification.measurement) # Build ground truth from the verified attestation self._ground_truth = GroundTruth( @@ -231,17 +237,26 @@ def get(self, url: str, headers: Dict[str, str] = {}) -> Response: ) return self.make_request(req) -def get_router_address() -> str: +def get_router_address(platform: Optional[str] = None) -> str: """ Fetches the list of available routers from the ATC API and returns a randomly selected address. + + Args: + platform: Optional platform filter (e.g. "snp", "tdx"). + If None, returns routers for any platform. """ + routers_url = "https://atc.tinfoil.sh/routers" + if platform: + routers_url += "?" + urlencode({"platform": platform}) - routers_url = "https://atc.tinfoil.sh/routers?platform=snp" + try: + with urllib.request.urlopen(routers_url, timeout=15) as response: + routers = json.loads(response.read().decode('utf-8')) + except (urllib.error.URLError, json.JSONDecodeError) as e: + raise ValueError(f"Failed to fetch router addresses: {e}") from e - with urllib.request.urlopen(routers_url) as response: - routers = json.loads(response.read().decode('utf-8')) - if len(routers) == 0: - raise ValueError("No routers found in the response") - - return random.choice(routers) + if not isinstance(routers, list) or len(routers) == 0: + raise ValueError("No routers found in the response") + + return random.choice(routers) From 37de88c3740fd06596909b63a914e2c7cf176d83 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:15:28 -0500 Subject: [PATCH 12/19] refactor: update attestation module exports and format dispatch Update __init__.py exports for new TDX/SEV module structure, add TDX format detection in attestation.py, remove deprecated V1 types, and handle empty release bodies in github.py. 
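The resulting dispatch, in sketch form (error handling elided):

    doc = fetch_attestation(host)  # GET https://{host}/.well-known/tinfoil-attestation
    verification = doc.verify()    # SEV_GUEST_V2 -> verify_sev_attestation_v2
                                   # TDX_GUEST_V2 -> verify_tdx_attestation_v2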
--- src/tinfoil/attestation/__init__.py | 34 +++- src/tinfoil/attestation/attestation.py | 208 +++---------------------- src/tinfoil/github.py | 32 ++-- 3 files changed, 74 insertions(+), 200 deletions(-) diff --git a/src/tinfoil/attestation/__init__.py b/src/tinfoil/attestation/__init__.py index 5697fd3..7e18196 100644 --- a/src/tinfoil/attestation/__init__.py +++ b/src/tinfoil/attestation/__init__.py @@ -1,17 +1,39 @@ +from .types import ( + PredicateType, + TDX_TYPES, + Measurement, + Verification, + HardwareMeasurement, + AttestationError, + FormatMismatchError, + MeasurementMismatchError, + Rtmr3NotZeroError, + HardwareMeasurementError, + RTMR3_ZERO, +) from .attestation import ( fetch_attestation, verify_attestation_json, - verify_sev_attestation_v2, - Measurement, - PredicateType, - from_snp_digest ) +from .attestation_tdx import verify_tdx_attestation_v2, TdxAttestationError, verify_tdx_hardware +from .attestation_sev import verify_sev_attestation_v2, SevAttestationError __all__ = [ 'fetch_attestation', 'verify_sev_attestation_v2', + 'verify_tdx_attestation_v2', + 'verify_tdx_hardware', 'verify_attestation_json', 'Measurement', + 'Verification', 'PredicateType', - 'from_snp_digest' -] \ No newline at end of file + 'RTMR3_ZERO', + 'AttestationError', + 'FormatMismatchError', + 'MeasurementMismatchError', + 'Rtmr3NotZeroError', + 'HardwareMeasurementError', + 'HardwareMeasurement', + 'TdxAttestationError', + 'SevAttestationError', +] diff --git a/src/tinfoil/attestation/attestation.py b/src/tinfoil/attestation/attestation.py index 2de3c12..7e2f3b5 100644 --- a/src/tinfoil/attestation/attestation.py +++ b/src/tinfoil/attestation/attestation.py @@ -1,76 +1,24 @@ -from dataclasses import dataclass -from enum import Enum -import json - -import base64 -import gzip import hashlib +import json import ssl -from typing import List, Optional +from dataclasses import dataclass + import requests from cryptography import x509 from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ec -from .validate import validate_report, ValidationOptions -from .verify import Report, verify_attestation, CertificateChain -from .abi_sevsnp import TCBParts, SnpPolicy, SnpPlatformInfo - - -class PredicateType(str, Enum): - """Predicate types for attestation""" - SEV_GUEST_V1 = "https://tinfoil.sh/predicate/sev-snp-guest/v1" # Deprecated - SEV_GUEST_V2 = "https://tinfoil.sh/predicate/sev-snp-guest/v2" - TDX_GUEST_V1 = "https://tinfoil.sh/predicate/tdx-guest/v1" # Deprecated - SNP_TDX_MULTIPLATFORM_v1 = "https://tinfoil.sh/predicate/snp-tdx-multiplatform/v1" +from .types import ( + PredicateType, + Measurement, + Verification, +) +from .attestation_tdx import verify_tdx_attestation_v2 +from .attestation_sev import verify_sev_attestation_v2 ATTESTATION_ENDPOINT = "/.well-known/tinfoil-attestation" +REQUEST_TIMEOUT_SECONDS = 15 -class AttestationError(Exception): - """Base class for attestation errors""" - pass - -class FormatMismatchError(AttestationError): - """Raised when attestation formats don't match""" - pass - -class MeasurementMismatchError(AttestationError): - """Raised when measurements don't match""" - pass - -@dataclass -class Measurement: - """Represents measurement data""" - type: PredicateType - registers: List[str] - - def fingerprint(self) -> str: - """ - Computes the SHA-256 hash of all measurements, - or returns the single measurement if there is only one - """ - if len(self.registers) == 1: - return self.registers[0] - - all_data = 
str(self.type) + "".join(self.registers) - return hashlib.sha256(all_data.encode()).hexdigest() - - def equals(self, other: 'Measurement') -> None: - """ - Checks if this measurement equals another measurement - Raises appropriate error if they don't match - """ - if self.type != other.type: - raise FormatMismatchError() - if len(self.registers) != len(other.registers) or self.registers != other.registers: - raise MeasurementMismatchError() - -@dataclass -class Verification: - """Represents verification results""" - measurement: Measurement - public_key_fp: str - hpke_public_key: Optional[str] = None @dataclass class Document: @@ -90,12 +38,21 @@ def verify(self) -> Verification: """ if self.format == PredicateType.SEV_GUEST_V2: return verify_sev_attestation_v2(self.body) + elif self.format == PredicateType.TDX_GUEST_V2: + return verify_tdx_attestation_v2(self.body) else: raise ValueError(f"Unsupported attestation format: {self.format}") def verify_attestation_json(json_data: bytes) -> Verification: """Verifies an attestation document in JSON format and returns the inner measurements""" - doc_dict = json.loads(json_data) + try: + doc_dict = json.loads(json_data) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Invalid attestation JSON: {e}") from e + + if not isinstance(doc_dict, dict) or "format" not in doc_dict or "body" not in doc_dict: + raise ValueError("Attestation JSON must contain 'format' and 'body' fields") + doc = Document( format=PredicateType(doc_dict["format"]), body=doc_dict["body"] @@ -130,129 +87,14 @@ def connection_cert_fp(ssl_socket: ssl.SSLSocket) -> str: def fetch_attestation(host: str) -> Document: """Retrieves the attestation document from a given enclave hostname""" url = f"https://{host}{ATTESTATION_ENDPOINT}" - response = requests.get(url) + response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS) response.raise_for_status() doc_dict = response.json() + if not isinstance(doc_dict, dict) or "format" not in doc_dict or "body" not in doc_dict: + raise ValueError(f"Invalid attestation response from {host}: missing 'format' or 'body'") + return Document( format=PredicateType(doc_dict["format"]), body=doc_dict["body"] ) - -min_tcb = TCBParts( - bl_spl=0x7, - tee_spl=0, - snp_spl=0xe, - ucode_spl=0x48, -) - -default_validation_options = ValidationOptions( - guest_policy=SnpPolicy( - abi_minor=0, - abi_major=0, - smt=True, - migrate_ma=False, - debug=False, - single_socket=False, - cxl_allowed=False, - mem_aes256_xts=False, - rapl_dis=False, - ciphertext_hiding_dram=False, - page_swap_disabled=False, - ), - minimum_guest_svn=0, - minimum_build=21, - minimum_version=(1 << 8) | 55, # 1.55 - minimum_tcb=min_tcb, - minimum_launch_tcb=min_tcb, - permit_provisional_firmware=False, # We only support False per your requirement - platform_info=SnpPlatformInfo( - smt_enabled=True, - tsme_enabled=True, - ecc_enabled=False, - rapl_disabled=False, - ciphertext_hiding_dram_enabled=False, - alias_check_complete=False, - tio_enabled=False, - ), - require_author_key=False, - require_id_block=False, -) - -def verify_sev_attestation_v2(attestation_doc: str) -> Verification: - """Verify SEV attestation document and return verification result.""" - report = verify_sev_report(attestation_doc, True) - - # Create measurement object - measurement = Measurement( - type=PredicateType.SEV_GUEST_V2, - registers=[ - report.measurement.hex() - ] - ) - - keys = report.report_data - tls_key_fp = keys[0:32] - hpke_public_key = keys[32:64] - - return Verification( - 
measurement=measurement, - public_key_fp=tls_key_fp.hex(), - hpke_public_key=hpke_public_key.hex() - ) - - -def verify_sev_report(attestation_doc: str, is_compressed: bool) -> Report: - """Verify SEV attestation document and return verification result.""" - try: - att_doc_bytes = base64.b64decode(attestation_doc) - except Exception as e: - raise ValueError(f"Failed to decode base64: {e}") - - if is_compressed: - att_doc_bytes = gzip.decompress(att_doc_bytes) - - # Parse the report - try: - report = Report(att_doc_bytes) - except Exception as e: - raise ValueError(f"Failed to parse report: {e}") - - # Get attestation chain - chain: CertificateChain = CertificateChain.from_report(report) - - # Verify attestation - try: - res = verify_attestation(chain, report) - except Exception as e: - raise ValueError(f"Failed to verify attestation: {e}") - - if not res: - raise ValueError("Attestation verification failed!") - - # Validate report - try: - validate_report(report, chain, default_validation_options) - except Exception as e: - raise ValueError(f"Failed to validate report: {e}") - - return report - -def from_snp_digest(snp_digest: str) -> dict: - """ - Convert an SNP launch digest string to measurement format. - - Args: - snp_digest: The SNP launch digest as a hex string - - Returns: - Dictionary in the format expected by SecureClient measurement parameter - - Example: - from tinfoil.attestation import from_snp_digest - measurement = from_snp_digest("abcdef") - client = TinfoilAI(measurement=measurement) - """ - return { - "snp_measurement": snp_digest - } diff --git a/src/tinfoil/github.py b/src/tinfoil/github.py index c46db23..f263528 100644 --- a/src/tinfoil/github.py +++ b/src/tinfoil/github.py @@ -8,8 +8,14 @@ # --- Cache Setup --- _GITHUB_CACHE_DIR = platformdirs.user_cache_dir("tinfoil", "tinfoil") -# Create a subdirectory specific to GitHub attestations within the main cache -os.makedirs(_GITHUB_CACHE_DIR, exist_ok=True) + + +def _ensure_github_cache_dir() -> None: + """Create the GitHub cache directory if it doesn't exist. 
Silently ignores failures.""" + try: + os.makedirs(_GITHUB_CACHE_DIR, exist_ok=True) + except OSError: + pass def _attestation_bundle_cache_path(repo: str, digest: str) -> str: """Generate a safe filepath for the attestation bundle cache.""" @@ -33,12 +39,14 @@ def fetch_latest_digest(repo: str) -> str: Exception: If there's any error fetching or parsing the data """ url = f"https://api-github-proxy.tinfoil.sh/repos/{repo}/releases/latest" - release_response = requests.get(url) + release_response = requests.get(url, timeout=15) release_response.raise_for_status() response_data = json.loads(release_response.content) + if not isinstance(response_data, dict) or "tag_name" not in response_data: + raise ValueError(f"Unexpected release API response for {repo}: missing 'tag_name'") tag_name = response_data["tag_name"] - body = response_data["body"] + body = response_data.get("body") or "" # Backwards compatibility for old EIF releases eif_regex = re.compile(r'EIF hash: ([a-fA-F0-9]{64})') @@ -54,7 +62,7 @@ def fetch_latest_digest(repo: str) -> str: # Fallback option: fetch digest from github special endpoint digest_url = f"https://github-proxy.tinfoil.sh/{repo}/releases/download/{tag_name}/tinfoil.hash" - response = requests.get(digest_url) + response = requests.get(digest_url, timeout=15) if response.status_code != 200: raise Exception(f"Failed to fetch attestation digest: {response.status_code} {response.reason}") return response.text.strip() @@ -115,15 +123,17 @@ def fetch_attestation_bundle(repo: str, digest: str) -> bytes: # Encode the string to bytes for file writing and return value bundle_bytes_to_write = bundle_json_string.encode('utf-8') - # Write to cache + # Atomic write to cache + _ensure_github_cache_dir() try: - with open(cache_path, 'wb') as f: + tmp_path = cache_path + ".tmp" + with open(tmp_path, 'wb') as f: f.write(bundle_bytes_to_write) - except OSError as e: - # Don't fail the whole operation if cache write fails, just warn - print(f"Warning: Failed to write cache file {cache_path}: {e}", file=sys.stderr) + os.replace(tmp_path, cache_path) + except OSError: + pass - return bundle_json_string + return bundle_bytes_to_write except (KeyError, IndexError, TypeError) as e: raise Exception(f"Invalid attestation response format from {url}: {e}. Response: {response_data}") from e From ddd5e7ece22b4250fdfdbbd0204ae0e008d92eaa Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:15:33 -0500 Subject: [PATCH 13/19] test: add integration and multi-enclave attestation tests Add parameterized all-enclaves integration test with graceful handling of transient network errors during test collection. Update existing attestation flow tests for new module structure. --- tests/test_attestation_all_enclaves.py | 189 +++++++++++++++++++++++++ tests/test_attestation_flow.py | 156 ++++++++++---------- 2 files changed, 271 insertions(+), 74 deletions(-) create mode 100644 tests/test_attestation_all_enclaves.py diff --git a/tests/test_attestation_all_enclaves.py b/tests/test_attestation_all_enclaves.py new file mode 100644 index 0000000..1ee8700 --- /dev/null +++ b/tests/test_attestation_all_enclaves.py @@ -0,0 +1,189 @@ +""" +Integration test for all live enclaves from the config file. + +Fetches config from GitHub, extracts all model enclaves, and verifies +attestation for each one. 
+
+Configure via environment variables:
+    TINFOIL_CONFIG_URL - URL to config.yml (default: main branch)
+    TINFOIL_API_KEY - API key for model tests (optional)
+
+Example:
+    python -m pytest tests/test_attestation_all_enclaves.py -v -s
+    python -m pytest tests/test_attestation_all_enclaves.py -v -s -k "llama" # filter by model name
+"""
+
+import os
+import pytest
+import requests
+
+try:
+    import yaml
+    HAS_YAML = True
+except ImportError:
+    HAS_YAML = False
+
+from tinfoil.client import SecureClient
+from tinfoil.attestation import PredicateType
+
+pytestmark = pytest.mark.integration
+
+# Config URL
+CONFIG_URL = os.environ.get(
+    "TINFOIL_CONFIG_URL",
+    "https://raw.githubusercontent.com/tinfoilsh/confidential-model-router/refs/heads/main/config.yml"
+)
+
+# Models to skip (not testable)
+SKIP_MODELS = ["websearch"]
+
+
+def fetch_config() -> dict:
+    """Fetch and parse the config file."""
+    if not HAS_YAML:
+        raise ImportError("PyYAML not installed. Install with: pip install pyyaml")
+    response = requests.get(CONFIG_URL, timeout=30)
+    response.raise_for_status()
+    return yaml.safe_load(response.text)
+
+
+def get_all_enclaves() -> list[tuple[str, str, str]]:
+    """
+    Get all enclaves from config.
+
+    Returns:
+        List of (model_name, repo, hostname) tuples
+
+    Raises:
+        ImportError: If PyYAML is not installed
+        Exception: If the config cannot be fetched or parsed
+    """
+    if not HAS_YAML:
+        raise ImportError("PyYAML not installed. Install with: pip install pyyaml")
+
+    config = fetch_config()
+
+    enclaves = []
+    models = config.get("models", {})
+
+    for model_name, model_config in models.items():
+        if any(skip.lower() in model_name.lower() for skip in SKIP_MODELS):
+            continue
+
+        repo = model_config.get("repo", "")
+        hostnames = model_config.get("enclaves", []) or model_config.get("hostnames", [])
+
+        if not repo or not hostnames:
+            continue
+
+        for hostname in hostnames:
+            enclaves.append((model_name, repo, hostname))
+
+    return enclaves
+
+
+def pytest_generate_tests(metafunc):
+    """Defer network call to test generation time instead of module import."""
+    if "enclave_config" in metafunc.fixturenames:
+        try:
+            enclaves = get_all_enclaves()
+        except Exception:
+            # Covers missing PyYAML (ImportError) as well as transient
+            # network errors during test collection
+            enclaves = []
+
+        if not enclaves:
+            enclaves = [("__skip__", "__skip__", "__skip__")]
+            ids = ["no_enclaves"]
+        else:
+            ids = [f"{n}@{h}" for n, _, h in enclaves]
+
+        metafunc.parametrize("enclave_config", enclaves, ids=ids)
+
+
+def test_enclave_attestation(enclave_config):
+    """
+    Test attestation for a single enclave.
+
+    Verifies:
+    1. Can connect to enclave
+    2. Attestation verification passes (crypto + policy)
+    3. Sigstore verification passes
+    4. Measurements match
+    5. For TDX: hardware measurements verified
+    """
+    model_name, repo, hostname = enclave_config
+
+    if model_name == "__skip__":
+        if not HAS_YAML:
+            pytest.skip("PyYAML not installed. 
Install with: pip install pyyaml") + else: + pytest.skip("No enclaves found in config") + + print(f"\n{'='*60}") + print(f"Testing: {model_name}") + print(f" Enclave: {hostname}") + print(f" Repo: {repo}") + print(f"{'='*60}") + + try: + client = SecureClient(enclave=hostname, repo=repo) + ground_truth = client.verify() + + # Print results + measurement_type = ground_truth.measurement.type + print(f"\n✓ Attestation verified!") + print(f" Architecture: {measurement_type.value}") + print(f" Fingerprint: {ground_truth.measurement.fingerprint()[:32]}...") + print(f" Public key: {ground_truth.public_key[:32]}...") + print(f" Digest: {ground_truth.digest[:32]}...") + + # Print architecture-specific info + regs = ground_truth.measurement.registers + if measurement_type == PredicateType.SEV_GUEST_V2: + print(f" SNP measurement: {regs[0][:32]}...") + elif measurement_type == PredicateType.TDX_GUEST_V2: + print(f" MRTD: {regs[0][:32]}...") + print(f" RTMR0: {regs[1][:32]}...") + + except Exception as e: + pytest.fail(f"Attestation failed for {model_name}@{hostname}: {e}") + + +def test_summary(): + """Print summary of all enclaves that will be tested.""" + if not HAS_YAML: + pytest.skip("PyYAML not installed. Install with: pip install pyyaml") + + try: + enclaves = get_all_enclaves() + except Exception as e: + pytest.fail(f"Failed to fetch enclave config: {e}") + + if not enclaves: + pytest.skip("No enclaves found in config") + + print(f"\n{'='*60}") + print(f"ENCLAVE SUMMARY") + print(f"{'='*60}") + print(f"Config: {CONFIG_URL}") + print(f"Total enclaves: {len(enclaves)}") + print(f"\nModels:") + + # Group by model + models = {} + for name, repo, host in enclaves: + if name not in models: + models[name] = {"repo": repo, "hosts": []} + models[name]["hosts"].append(host) + + for name, info in sorted(models.items()): + print(f"\n {name}:") + print(f" Repo: {info['repo']}") + print(f" Enclaves: {len(info['hosts'])}") + for host in info['hosts']: + print(f" - {host}") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_attestation_flow.py b/tests/test_attestation_flow.py index 2f8cf28..750c911 100644 --- a/tests/test_attestation_flow.py +++ b/tests/test_attestation_flow.py @@ -1,89 +1,97 @@ +""" +Integration test for attestation verification flow. + +Tests the complete verification using the SecureClient API with a live router. +Works with any architecture (SNP or TDX) returned by the router service. +""" + import pytest -# Adjust these imports based on your project structure -from tinfoil.github import fetch_latest_digest, fetch_attestation_bundle -from tinfoil.sigstore import verify_attestation -from tinfoil.attestation import fetch_attestation -from tinfoil.client import get_router_address +from tinfoil.client import SecureClient, get_router_address +from tinfoil.attestation import PredicateType pytestmark = pytest.mark.integration # allows pytest -m integration filtering -# Fetch config from environment variables, falling back to defaults -# Use the same env vars as the other integration test for consistency +# Router always runs confidential-model-router REPO = "tinfoilsh/confidential-model-router" + def test_full_verification_flow(): """ - Tests the complete attestation verification flow: - 1. Fetch latest digest for the repository. - 2. Fetch the sigstore attestation bundle for that digest. - 3. Verify the sigstore bundle to get code measurements. - 4. Fetch the runtime attestation from the enclave. - 5. Verify the runtime attestation. - 6. 
Compare code measurements with runtime measurements. + Tests the complete attestation verification flow using SecureClient. + + Gets a router from the ATC service and verifies it against the + confidential-model-router repo. Works with any TEE type (SNP or TDX). + + SecureClient.verify() performs: + 1. Fetch runtime attestation from enclave + 2. Verify attestation (cryptographic + policy validation) + 3. Fetch latest digest from GitHub + 4. Fetch and verify sigstore attestation bundle + 5. For TDX: verify hardware measurements (MRTD, RTMR0) + 6. Compare code measurements with runtime measurements """ try: - # Fetch enclave address lazily inside the test to avoid import-time network calls - try: - enclave = get_router_address() - except Exception as e: - pytest.skip(f"Could not fetch router address from ATC service: {e}") - return - - # Fetch latest release digest - print(f"Fetching latest release for {REPO}") - digest = fetch_latest_digest(REPO) - print(f"Found digest: {digest}") - - # Fetch attestation bundle - print(f"Fetching attestation bundle for {REPO}@{digest}") - sigstore_bundle = fetch_attestation_bundle(REPO, digest) - assert sigstore_bundle is not None # Basic check - - # Verify attested measurements from sigstore bundle - print(f"Verifying attested measurements for {REPO}@{digest}") - code_measurements = verify_attestation( - sigstore_bundle, - digest, - REPO - ) - assert code_measurements is not None # Basic check - print(f"Code measurements fingerprint: {code_measurements.fingerprint()}") - - - # Fetch runtime attestation from the enclave - print(f"Fetching runtime attestation from {enclave}") - enclave_attestation = fetch_attestation(enclave) - assert enclave_attestation is not None # Basic check - - # Verify enclave measurements from runtime attestation - print("Verifying enclave measurements") - runtime_verification = enclave_attestation.verify() - assert runtime_verification is not None # Basic check - print(f"Runtime measurement fingerprint: {runtime_verification.measurement.fingerprint()}") - print(f"Public key fingerprint: {runtime_verification.public_key_fp}") - - - # Compare measurements - print("Comparing measurements") - assert len(code_measurements.registers) == len(runtime_verification.measurement.registers), \ - "Number of measurement registers differ" - - for i, code_reg in enumerate(code_measurements.registers): - runtime_reg = runtime_verification.measurement.registers[i] - assert code_reg == runtime_reg, \ - f"Measurement register {i} mismatch: Code='{code_reg}' vs Runtime='{runtime_reg}'" - - print("Verification successful!") - print(f"Public key fingerprint: {runtime_verification.public_key_fp}") - print(f"Measurement: {code_measurements.fingerprint()}") - + enclave = get_router_address() except Exception as e: - import traceback - traceback.print_exc() - pytest.fail(f"Verification flow failed with exception: {e}") + pytest.skip(f"Could not fetch router address from ATC service: {e}") + + print(f"\nVerifying enclave: {enclave}") + print(f"Against repo: {REPO}") + + client = SecureClient(enclave=enclave, repo=REPO) + ground_truth = client.verify() + + # Print architecture-specific info + measurement_type = ground_truth.measurement.type + print(f"\n✓ Verification successful!") + print(f" Architecture: {measurement_type.value}") + print(f" Measurement fingerprint: {ground_truth.measurement.fingerprint()}") + print(f" Public key fingerprint: {ground_truth.public_key}") + print(f" Digest: {ground_truth.digest}") + + # Print registers based on type + regs = 
ground_truth.measurement.registers
+    if measurement_type == PredicateType.SEV_GUEST_V2:
+        print(f"\n  SNP Measurement: {regs[0][:32]}...")
+    elif measurement_type == PredicateType.TDX_GUEST_V2:
+        print(f"\n  TDX Measurements:")
+        print(f"    MRTD:  {regs[0][:32]}...")
+        print(f"    RTMR0: {regs[1][:32]}...")
+        print(f"    RTMR1: {regs[2][:32]}...")
+        print(f"    RTMR2: {regs[3][:32]}...")
+        print(f"    RTMR3: {regs[4][:32]}...")
+
+
+def test_secure_http_client():
+    """
+    Tests that SecureClient creates a working pinned HTTP client
+    and that TLS pinning is exercised by issuing an actual request.
+    Works with any TEE type (SNP or TDX).
+    """
+    try:
+        enclave = get_router_address()
+    except Exception as e:
+        pytest.skip(f"Could not fetch router address from ATC service: {e}")
+
+    print(f"\nCreating secure HTTP client for: {enclave}")
+
+    client = SecureClient(enclave=enclave, repo=REPO)
+    http_client = client.make_secure_http_client()
+
+    ground_truth = client.ground_truth
+    assert ground_truth is not None
+
+    try:
+        response = http_client.get(f"https://{enclave}/.well-known/tinfoil-attestation")
+        assert response.status_code == 200, f"Expected 200, got {response.status_code}"
+        print(f"\n✓ TLS-pinned request succeeded (status {response.status_code})")
+    finally:
+        http_client.close()
+
+    print(f"  Architecture: {ground_truth.measurement.type.value}")
+    print(f"  TLS pinned to: {ground_truth.public_key}")
 
 
 if __name__ == "__main__":
-    # Allow running the test directly using `python tests/test_verification_flow.py`
-    pytest.main([__file__])
+    pytest.main([__file__, "-v", "-s"])

From 2428414fb6739f8cc82cbf1331a672d73d5861a5 Mon Sep 17 00:00:00 2001
From: Jules Drean
Date: Mon, 23 Feb 2026 16:15:38 -0500
Subject: [PATCH 14/19] chore: add pyasn1 dependency, update gitignore

Add pyasn1>=0.4.0 as explicit dependency for ASN.1 parsing. Ignore log
files and generated documentation builds in gitignore.
---
 .gitignore     | 8 +++++++-
 pyproject.toml | 1 +
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 2f8fa1f..0c558fb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,4 +174,10 @@ cython_debug/
 tuf-repo-cdn.sigstore.dev.json
 verifier/
 tinfoil/tinfoil_verifier/
-.DS_Store
\ No newline at end of file
+.DS_Store
+
+# Logs
+*.log
+
+# Generated documentation
+docs/_build/
diff --git a/pyproject.toml b/pyproject.toml
index 0682a33..b1e6fa0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
     "requests>=2.31.0",
     "cryptography>=42.0.0",
     "pyOpenSSL>=25.0.0",
+    "pyasn1>=0.4.0",
     "sigstore>=4.1.0",
     "platformdirs>=4.2.0",
     "pytest-asyncio>=0.26.0"

From 7fefe37eecea59099b72f2f32f96bc656e38a9cf Mon Sep 17 00:00:00 2001
From: Jules Drean
Date: Mon, 23 Feb 2026 16:15:46 -0500
Subject: [PATCH 15/19] refactor: remove old SEV module files replaced by
 dedicated modules

Remove abi_sevsnp.py (replaced by abi_sev.py), validate.py (replaced by
validate_sev.py), and verify.py (replaced by verify_sev.py).
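For code that imported the removed modules directly, the change is a
module rename. A migration sketch (assuming the symbols keep their
names in the new modules):

    # Before (modules removed by this commit):
    #   from tinfoil.attestation.abi_sevsnp import Report, TCBParts
    #   from tinfoil.attestation.verify import CertificateChain
    #   from tinfoil.attestation.validate import validate_report, ValidationOptions

    # After:
    from tinfoil.attestation.abi_sev import Report, TCBParts
    from tinfoil.attestation.verify_sev import CertificateChain
    from tinfoil.attestation.validate_sev import validate_report, ValidationOptions

    # Package-level imports are unaffected by the rename:
    from tinfoil.attestation import verify_sev_attestation_v2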
--- src/tinfoil/attestation/abi_sevsnp.py | 449 ------------------------ src/tinfoil/attestation/validate.py | 288 ---------------- src/tinfoil/attestation/verify.py | 478 -------------------------- 3 files changed, 1215 deletions(-) delete mode 100644 src/tinfoil/attestation/abi_sevsnp.py delete mode 100644 src/tinfoil/attestation/validate.py delete mode 100644 src/tinfoil/attestation/verify.py diff --git a/src/tinfoil/attestation/abi_sevsnp.py b/src/tinfoil/attestation/abi_sevsnp.py deleted file mode 100644 index cb9e079..0000000 --- a/src/tinfoil/attestation/abi_sevsnp.py +++ /dev/null @@ -1,449 +0,0 @@ -from dataclasses import dataclass -from enum import IntEnum - -POLICY_RESERVED_1_BIT = 17 -REPORT_SIZE = 0x4A0 # 1184 bytes -SIGNATURE_OFFSET = 0x2A0 -ECDSA_RS_SIZE = 72 -ECDSA_P384_SHA384_SIGNATURE_SIZE = ECDSA_RS_SIZE + ECDSA_RS_SIZE - -ZEN3ZEN4_FAMILY = 0x19 -ZEN5_FAMILY = 0x1A -MILAN_MODEL = 0 | 1 -GENOA_MODEL = (1 << 4) | 1 -TURIN_MODEL = 2 - -class ReportSigner(IntEnum): - VcekReportSigner = 0 - # VlekReportSigner is the SIGNING_KEY value for if the VLEK signed the attestation report. - VlekReportSigner = 1 - endorseReserved2 = 2 - endorseReserved3 = 3 - endorseReserved4 = 4 - endorseReserved5 = 5 - endorseReserved6 = 6 - # NoneReportSigner is the SIGNING_KEY value for if the attestation report is not signed. - NoneReportSigner = 7 - -# SignerInfo represents information about the signing circumstances for the attestation report. -class SignerInfo: - # SigningKey represents kind of key by which a report was signed. - signingKey: ReportSigner - # MaskChipKey is true if the host chose to enable CHIP_ID masking, to cause the report's CHIP_ID - # to be all zeros. - maskChipKey: bool - # AuthorKeyEn is true if the VM is launched with an IDBLOCK that includes an author key. 
- authorKeyEn: bool - -@dataclass -class TCBParts: - """Represents the decomposed parts of a TCB version""" - ucode_spl: int - snp_spl: int - tee_spl: int - bl_spl: int - - def __str__(self) -> str: - """Return a human-friendly string with all component SPL values.""" - # Print fields in order starting with the least-significant component (bl_spl) - return ( - "TCBParts(" # Opening - f"bl_spl=0x{self.bl_spl:02x}, " - f"tee_spl=0x{self.tee_spl:02x}, " - f"snp_spl=0x{self.snp_spl:02x}, " - f"ucode_spl=0x{self.ucode_spl:02x})" - ) - - @classmethod - def from_int(cls, tcb: int) -> "TCBParts": - """Build a TCBParts instance from a 64-bit packed TCB value.""" - return cls( - ucode_spl=((tcb >> 56) & 0xff), - snp_spl=((tcb >> 48) & 0xff), - tee_spl=((tcb >> 8) & 0xff), - bl_spl=((tcb >> 0) & 0xff), - ) - - def meets_minimum(self, minimum: "TCBParts") -> bool: - """Check if this TCB meets minimum requirements (component-wise).""" - return ( - self.bl_spl >= minimum.bl_spl and - self.tee_spl >= minimum.tee_spl and - self.snp_spl >= minimum.snp_spl and - self.ucode_spl >= minimum.ucode_spl - ) - -@dataclass -class SnpPlatformInfo: - """Decoded view of the 64-bit PLATFORM_INFO field.""" - - smt_enabled: bool - tsme_enabled: bool - ecc_enabled: bool - rapl_disabled: bool - ciphertext_hiding_dram_enabled: bool - alias_check_complete: bool - tio_enabled: bool - - @classmethod - def from_int(cls, value: int) -> "SnpPlatformInfo": - return cls( - smt_enabled=bool(value & (1 << 0)), - tsme_enabled=bool(value & (1 << 1)), - ecc_enabled=bool(value & (1 << 2)), - rapl_disabled=bool(value & (1 << 3)), - ciphertext_hiding_dram_enabled=bool(value & (1 << 4)), - alias_check_complete=bool(value & (1 << 5)), - tio_enabled=bool(value & (1 << 7)) - ) - - def __str__(self) -> str: # pragma: no cover – formatting helper - return ( - "SnpPlatformInfo(" # opening - f"SMTEnabled={self.smt_enabled}, " - f"TSMEEnabled={self.tsme_enabled}, " - f"ECCEnabled={self.ecc_enabled}, " - f"RAPLDisabled={self.rapl_disabled}, " - f"CiphertextHidingDRAMEnabled={self.ciphertext_hiding_dram_enabled}, " - f"AliasCheckComplete={self.alias_check_complete}, " - f"TIOEnabled={self.tio_enabled})" - ) - - -@dataclass -class SnpPolicy: - """Decoded view of the 64-bit POLICY field (bits 0-20).""" - - abi_minor: int - abi_major: int - smt: bool - migrate_ma: bool - debug: bool - single_socket: bool - cxl_allowed: bool - mem_aes256_xts: bool - rapl_dis: bool - ciphertext_hiding_dram: bool - page_swap_disabled: bool - - def __str__(self) -> str: # pragma: no cover – formatting helper - return ( - "SnpPolicy(" # opening - f"ABIMajor={self.abi_major}, ABIMinor={self.abi_minor}, " - f"SMT={self.smt}, MigrateMA={self.migrate_ma}, Debug={self.debug}, " - f"SingleSocket={self.single_socket}, CXLAllowed={self.cxl_allowed}, " - f"MemAES256XTS={self.mem_aes256_xts}, RAPLDis={self.rapl_dis}, " - f"CipherTextHidingDRAM={self.ciphertext_hiding_dram}, PageSwapDisabled={self.page_swap_disabled})" - ) - - @classmethod - def from_int(cls, value: int) -> "SnpPolicy": - """Parse the guest policy bit-field following AMD SEV-SNP spec.""" - return cls( - abi_minor=value & 0xFF, - abi_major=(value >> 8) & 0xFF, - smt=bool(value & (1 << 16)), - migrate_ma=bool(value & (1 << 18)), - debug=bool(value & (1 << 19)), - single_socket=bool(value & (1 << 20)), - cxl_allowed=bool(value & (1 << 21)), - mem_aes256_xts=bool(value & (1 << 22)), - rapl_dis=bool(value & (1 << 23)), - ciphertext_hiding_dram=bool(value & (1 << 24)), - page_swap_disabled=bool(value & (1 << 25)), - ) - 
-@dataclass -class Report: - """SEV-SNP attestation report""" - version: int # Should be 2 for revision 1.55, 3 for revision 1.56, 5 for revision 1.58 - guest_svn: int - policy: int - policy_parsed: SnpPolicy - family_id: bytes # Should be 16 bytes long - image_id: bytes # Should be 16 bytes long - vmpl: int - signature_algo: int - current_tcb: int - platform_info: int - platform_info_parsed: SnpPlatformInfo - signer_info: int # AuthorKeyEn, MaskChipKey, SigningKey - signer_info_parsed: SignerInfo - report_data: bytes # Should be 64 bytes long - measurement: bytes # Should be 48 bytes long - host_data: bytes # Should be 32 bytes long - id_key_digest: bytes # Should be 48 bytes long - author_key_digest: bytes # Should be 48 bytes long - report_id: bytes # Should be 32 bytes long - report_id_ma: bytes # Should be 32 bytes long - reported_tcb: int - chip_id: bytes # Should be 64 bytes long - committed_tcb: int - current_build: int - current_minor: int - current_major: int - committed_build: int - committed_minor: int - committed_major: int - launch_tcb: int - signed_data: bytes - signature: bytes # Should be 512 bytes long - family: bytes - model: bytes - stepping: bytes - productName: str - - def __init__(self, data: bytes): - """ - Parse an attestation report from raw bytes in SEV SNP ABI format. - - Args: - data: Raw bytes of the attestation report - Returns: - Report object containing parsed data - """ - - if len(data) < REPORT_SIZE: - raise ValueError(f"Array size is 0x{len(data):x}, an SEV-SNP attestation report size is 0x{REPORT_SIZE:x}") - - # Parse all fields using little-endian byte order - self.version = int.from_bytes(data[0x00:0x04], byteorder='little') - self.guest_svn = int.from_bytes(data[0x04:0x08], byteorder='little') - self.policy = int.from_bytes(data[0x08:0x10], byteorder='little') - - # Check reserved bit must be 1 - if not (self.policy & (1 << POLICY_RESERVED_1_BIT)): - raise ValueError(f"policy[{POLICY_RESERVED_1_BIT}] is reserved, must be 1, got 0") - - # Check bits 63-26 must be zero - if self.policy >> 26: - raise ValueError("policy bits 63-26 must be zero") - - self.family_id = data[0x10:0x20] # 16 bytes - self.image_id = data[0x20:0x30] # 16 bytes - self.vmpl = int.from_bytes(data[0x30:0x34], byteorder='little') - self.signature_algo = int.from_bytes(data[0x34:0x38], byteorder='little') - self.current_tcb = int.from_bytes(data[0x38:0x40], byteorder='little') - - try: - mbz64(int(self.current_tcb), "current_tcb", 47, 16) - except ValueError as e: - raise ValueError(f"current_tcb not correctly formed: {e}") - - self.platform_info = int.from_bytes(data[0x40:0x48], byteorder='little') - # Decode additional helper structures for easier consumption later. - self.policy_parsed = SnpPolicy.from_int(self.policy) - self.platform_info_parsed = SnpPlatformInfo.from_int(self.platform_info) - self.signer_info = int.from_bytes(data[0x48:0x4C], byteorder='little') - - self.signer_info_parsed = SignerInfo() - try: - mbz64(int(self.signer_info), "signer_info", 31, 5) - except ValueError as e: - raise ValueError(f"signer_info not correctly formed: {e}") - - self.signer_info_parsed.signingKey = ReportSigner((self.signer_info >> 2) & 7) - if self.signer_info_parsed.signingKey != ReportSigner.VcekReportSigner: - raise ValueError(f"This implementation only supports VCEK signed reports. 
Got {self.signer_info_parsed.signingKey}") - self.signer_info_parsed.maskChipKey = (self.signer_info & 2) != 0 - self.signer_info_parsed.authorKeyEn = (self.signer_info & 1) != 0 - - try: - mbz(data, 0x4C, 0x50) - except ValueError as e: - raise ValueError(f"report_data not correctly formed: {e}") - - # 0x4C-0x50 is MBZ (Must Be Zero) - self.report_data = data[0x50:0x90] # 64 bytes - self.measurement = data[0x90:0xC0] # 48 bytes - self.host_data = data[0xC0:0xE0] # 32 bytes - self.id_key_digest = data[0xE0:0x110] # 48 bytes - self.author_key_digest = data[0x110:0x140] # 48 bytes - self.report_id = data[0x140:0x160] # 32 bytes - self.report_id_ma = data[0x160:0x180] # 32 bytes - self.reported_tcb = int.from_bytes(data[0x180:0x188], byteorder='little') - - try: - mbz64(int(self.reported_tcb), "reported_tcb", 47, 16) - except ValueError as e: - raise ValueError(f"reported_tcb not correctly formed: {e}") - - mbzLo = 0x188 - # Version specific parsing - if self.version >= 3: # Report Version 3 - self.family = data[0x188] - self.model = data[0x189] - self.stepping = data[0x18A] - self._init_product_name() - mbzLo = 0x18B - elif self.version == 2: # Report Version 2 - self.family = ZEN3ZEN4_FAMILY - self.model = GENOA_MODEL - self.stepping = 0x01 - self.productName = "Genoa" - else: - raise ValueError("Unknown report version") - - try: - mbz(data, mbzLo, 0x1A0) - except ValueError as e: - raise ValueError(f"report_data not correctly formed: {e}") - - self.chip_id = data[0x1A0:0x1E0] # 64 bytes - self.committed_tcb = int.from_bytes(data[0x1E0:0x1E8], byteorder='little') - - try: - mbz64(int(self.committed_tcb), "committed_tcb", 47, 16) - except ValueError as e: - raise ValueError(f"committed_tcb not correctly formed: {e}") - - # Version fields - self.current_build = data[0x1E8] - self.current_minor = data[0x1E9] - self.current_major = data[0x1EA] - - try: - mbz(data, 0x1EB, 0x1EC) - except ValueError as e: - raise ValueError(f"report_data not correctly formed: {e}") - - self.committed_build = data[0x1EC] - self.committed_minor = data[0x1ED] - self.committed_major = data[0x1EE] - - try: - mbz(data, 0x1EF, 0x1F0) - except ValueError as e: - raise ValueError(f"report_data not correctly formed: {e}") - - self.launch_tcb = int.from_bytes(data[0x1F0:0x1F8], byteorder='little') - - try: - mbz64(int(self.launch_tcb), "launch_tcb", 47, 16) - except ValueError as e: - raise ValueError(f"launch_tcb not correctly formed: {e}") - - try: - mbz(data, 0x1F8, SIGNATURE_OFFSET) - except ValueError as e: - raise ValueError(f"report_data not correctly formed: {e}") - - if self.signature_algo == 1: # ECDSA P-384 SHA-384 - try: - mbz(data, SIGNATURE_OFFSET+ECDSA_P384_SHA384_SIGNATURE_SIZE, REPORT_SIZE) - except ValueError as e: - raise ValueError(f"report_data not correctly formed: {e}") - - self.signed_data = data[0:SIGNATURE_OFFSET] - self.signature = data[SIGNATURE_OFFSET:REPORT_SIZE] - - def get_fms(self): - return self.family, self.model, self.stepping - - def _init_product_name(self): - # Combined extended values - self.productName = "Unknown" - if self.family == ZEN3ZEN4_FAMILY: - if self.model == MILAN_MODEL: - self.productName = "Milan" - elif self.model == GENOA_MODEL: - self.productName = "Genoa" - elif self.family == ZEN5_FAMILY: - if self.model == TURIN_MODEL: - self.productName = "Turin" - - def print_report(self): - """Print all relevant fields of the SEV-SNP attestation report in a human-readable format.""" - print("=== SEV-SNP Attestation Report ===") - print(f"Version: {self.version}") - print(f"Guest 
SVN: {self.guest_svn}") - print(f"Policy: 0x{self.policy:x}") - print(f" -> {self.policy_parsed}") - print(f"Family ID: {self.family_id.hex()}") - print(f"Image ID: {self.image_id.hex()}") - print(f"VMPL: {self.vmpl}") - print(f"Signature Algorithm: {self.signature_algo}") - print(f"Current TCB: 0x{self.current_tcb:x}") - print(f" -> {TCBParts.from_int(self.current_tcb)}") - print(f"Platform Info: 0x{self.platform_info:x}") - print(f" -> {self.platform_info_parsed}") - print(f"Signer Info: 0x{self.signer_info:x}") - print(f" - Signing Key: {self.signer_info_parsed.signingKey}") - print(f" - Mask Chip Key: {self.signer_info_parsed.maskChipKey}") - print(f" - Author Key Enabled: {self.signer_info_parsed.authorKeyEn}") - print(f"Report Data: {self.report_data.hex()}") - print(f"Measurement: {self.measurement.hex()}") - print(f"Host Data: {self.host_data.hex()}") - print(f"ID Key Digest: {self.id_key_digest.hex()}") - print(f"Author Key Digest: {self.author_key_digest.hex()}") - print(f"Report ID: {self.report_id.hex()}") - print(f"Report ID MA: {self.report_id_ma.hex()}") - print(f"Reported TCB: 0x{self.reported_tcb:x}") - print(f" -> {TCBParts.from_int(self.reported_tcb)}") - print(f"Chip ID: {self.chip_id.hex()}") - print(f"Committed TCB: 0x{self.committed_tcb:x}") - print(f" -> {TCBParts.from_int(self.committed_tcb)}") - print(f"Current Version: {self.current_major}.{self.current_minor}.{self.current_build}") - print(f"Committed Version: {self.committed_major}.{self.committed_minor}.{self.committed_build}") - print(f"Launch TCB: 0x{self.launch_tcb:x}") - print(f" -> {TCBParts.from_int(self.launch_tcb)}") - print(f"Product Name: {self.productName}") - if hasattr(self, 'family') and hasattr(self, 'model') and hasattr(self, 'stepping'): - print(f"CPU: Family=0x{self.family:02x}, Model=0x{self.model:02x}, Stepping=0x{self.stepping:02x}") - print(f"Signature Length: {len(self.signature)} bytes") - print("=" * 40) - -## HELPER FUNCTIONS - -def find_non_zero(data: bytes, lo: int, hi: int) -> int: - """ - Returns the first index which is not zero, otherwise returns hi. - - Args: - data: Bytes object to search through - lo: Starting index (inclusive) - hi: Ending index (exclusive) - Returns: - Index of first non-zero byte, or hi if all bytes are zero - """ - for i in range(lo, hi): - if data[i] != 0: - return i - return hi - -def mbz(data: bytes, lo: int, hi: int) -> None: - """ - Checks if a range of bytes is all zeros. - - Args: - data: Bytes object to check - lo: Starting index (inclusive) - hi: Ending index (exclusive) - Raises: - ValueError: If any byte in the range is non-zero - """ - first_non_zero = find_non_zero(data, lo, hi) - if first_non_zero != hi: - # Convert the slice to hex string for error message - hex_str = data[lo:hi].hex() - raise ValueError(f"mbz range [0x{lo:x}:0x{hi:x}] not all zero: {hex_str}") - -def mbz64(data: int, base: str, hi: int, lo: int) -> None: - """ - Checks if a range of bits in an integer is all zeros. 
- - Args: - data: Integer to check - base: String identifier for error message - hi: Highest bit position (inclusive) - lo: Lowest bit position (inclusive) - Raises: - ValueError: If any bit in the range is non-zero - """ - # Create mask for the bit range - mask = (1 << (hi - lo + 1)) - 1 - # Extract and check the bits - bits = (data >> lo) & mask - if bits != 0: - raise ValueError(f"mbz range {base}[0x{lo:x}:0x{hi:x}] not all zero: {hex(data)}") - diff --git a/src/tinfoil/attestation/validate.py b/src/tinfoil/attestation/validate.py deleted file mode 100644 index d1a9997..0000000 --- a/src/tinfoil/attestation/validate.py +++ /dev/null @@ -1,288 +0,0 @@ -from dataclasses import dataclass, field -from typing import Optional, List, Dict - -from .abi_sevsnp import ( - Report, - TCBParts, - SnpPolicy, - SnpPlatformInfo, - ReportSigner, -) -from .verify import CertificateChain - -@dataclass -class ValidationOptions: - """ - Verification options for an SEV-SNP attestation report. - Any attribute left as ``None`` / empty will not be checked - by the validation routine. - """ - # Policy / version constraints - guest_policy: Optional[SnpPolicy] = None - minimum_guest_svn: Optional[int] = None - minimum_build: Optional[int] = None # Firmware build (uint8) - minimum_version: Optional[int] = None # Firmware API version (uint16) - - # TCB requirements - minimum_tcb: Optional[TCBParts] = None - minimum_launch_tcb: Optional[TCBParts] = None - permit_provisional_firmware: bool = False - - # Field equality checks (length is not enforced here; caller must ensure correctness) - report_data: Optional[bytes] = None # 64 bytes - host_data: Optional[bytes] = None # 32 bytes - image_id: Optional[bytes] = None # 16 bytes - family_id: Optional[bytes] = None # 16 bytes - report_id: Optional[bytes] = None # 32 bytes - report_id_ma: Optional[bytes] = None # 32 bytes - measurement: Optional[bytes] = None # 48 bytes - chip_id: Optional[bytes] = None # 64 bytes - - # Misc - platform_info: Optional[SnpPlatformInfo] = None - vmpl: Optional[int] = None # Expected VMPL (0-3) - - # TODO: ID-block / author key requirements - require_author_key: bool = False - require_id_block: bool = False - # trusted_author_keys: List[x509.Certificate] = field(default_factory=list) - # trusted_author_key_hashes: List[bytes] = field(default_factory=list) - # trusted_id_keys: List[x509.Certificate] = field(default_factory=list) - # trusted_id_key_hashes: List[bytes] = field(default_factory=list) - - # TODO:Extended certificate-table options - # cert_table_options: Dict[str, CertEntryOption] = field(default_factory=dict) - - -def validate_report(report: Report, chain: CertificateChain, options: ValidationOptions): - """ - Validate the supplied SEV-SNP attestation report according to *options*. - Raises ValueError if validation fails. 
- """ - - # Policy constraints - if options.guest_policy is not None: - _validate_policy(report.policy_parsed, options.guest_policy) - - if options.minimum_guest_svn is not None: - if report.guest_svn < options.minimum_guest_svn: - raise ValueError(f"Guest SVN {report.guest_svn} is less than minimum required {options.minimum_guest_svn}") - - if options.minimum_build is not None: - if report.current_build < options.minimum_build: - raise ValueError(f"Current SNP firmware build number {report.current_build} is less than minimum required {options.minimum_build}") - if report.committed_build < options.minimum_build: - raise ValueError(f"Committed SNP firmware build number {report.committed_build} is less than minimum required {options.minimum_build}") - - if options.minimum_version is not None: - # Combine major/minor into single version number for comparison - current_version = (report.current_major << 8) | report.current_minor - committed_version = (report.committed_major << 8) | report.committed_minor - if current_version < options.minimum_version: - raise ValueError(f"Current SNP firmwareversion {report.current_major}.{report.current_minor} is less than minimum required {options.minimum_version >> 8}.{options.minimum_version & 0xff}") - if committed_version < options.minimum_version: - raise ValueError(f"Committed SNP firmware version {report.committed_major}.{report.committed_minor} is less than minimum required {options.minimum_version >> 8}.{options.minimum_version & 0xff}") - - # TCB requirements - if options.minimum_tcb is not None: - current_tcb_parts = TCBParts.from_int(report.current_tcb) - committed_tcb_parts = TCBParts.from_int(report.committed_tcb) - reported_tcb_parts = TCBParts.from_int(report.reported_tcb) - if not current_tcb_parts.meets_minimum(options.minimum_tcb): - raise ValueError(f"Current TCB {current_tcb_parts} does not meet minimum requirements {options.minimum_tcb}") - if not committed_tcb_parts.meets_minimum(options.minimum_tcb): - raise ValueError(f"Committed TCB {committed_tcb_parts} does not meet minimum requirements {options.minimum_tcb}") - if not reported_tcb_parts.meets_minimum(options.minimum_tcb): - raise ValueError(f"Reported TCB {reported_tcb_parts} does not meet minimum requirements {options.minimum_tcb}") - - # VCEK-specific TCB check - chain.validate_vcek_tcb(TCBParts.from_int(report.reported_tcb)) - - if options.minimum_launch_tcb is not None: - launch_tcb_parts = TCBParts.from_int(report.launch_tcb) - if not launch_tcb_parts.meets_minimum(options.minimum_launch_tcb): - raise ValueError(f"Launch TCB {launch_tcb_parts} does not meet minimum requirements {options.minimum_launch_tcb}") - - # Field equality checks - if options.report_data is not None: - if len(report.report_data) != 64: - raise ValueError(f"Report data length is {len(report.report_data)}, expected 64 bytes") - if report.report_data != options.report_data: - raise ValueError(f"Report data mismatch: got {report.report_data.hex()}, expected {options.report_data.hex()}") - - if options.host_data is not None: - if len(report.host_data) != 32: - raise ValueError(f"Host data length is {len(report.host_data)}, expected 32 bytes") - if report.host_data != options.host_data: - raise ValueError(f"Host data mismatch: got {report.host_data.hex()}, expected {options.host_data.hex()}") - - if options.image_id is not None: - if len(report.image_id) != 16: - raise ValueError(f"Image ID length is {len(report.image_id)}, expected 16 bytes") - if report.image_id != options.image_id: - raise 
ValueError(f"Image ID mismatch: got {report.image_id.hex()}, expected {options.image_id.hex()}") - - if options.family_id is not None: - if len(report.family_id) != 16: - raise ValueError(f"Family ID length is {len(report.family_id)}, expected 16 bytes") - if report.family_id != options.family_id: - raise ValueError(f"Family ID mismatch: got {report.family_id.hex()}, expected {options.family_id.hex()}") - - if options.report_id is not None: - if len(report.report_id) != 32: - raise ValueError(f"Report ID length is {len(report.report_id)}, expected 32 bytes") - if report.report_id != options.report_id: - raise ValueError(f"Report ID mismatch: got {report.report_id.hex()}, expected {options.report_id.hex()}") - - if options.report_id_ma is not None: - if len(report.report_id_ma) != 32: - raise ValueError(f"Report ID MA length is {len(report.report_id_ma)}, expected 32 bytes") - if report.report_id_ma != options.report_id_ma: - raise ValueError(f"Report ID MA mismatch: got {report.report_id_ma.hex()}, expected {options.report_id_ma.hex()}") - - if options.measurement is not None: - if len(report.measurement) != 48: - raise ValueError(f"Measurement length is {len(report.measurement)}, expected 48 bytes") - if report.measurement != options.measurement: - raise ValueError(f"Measurement mismatch: got {report.measurement.hex()}, expected {options.measurement.hex()}") - - if options.chip_id is not None: - if len(report.chip_id) != 64: - raise ValueError(f"Chip ID length is {len(report.chip_id)}, expected 64 bytes") - if report.chip_id != options.chip_id: - raise ValueError(f"Chip ID mismatch: got {report.chip_id.hex()}, expected {options.chip_id.hex()}") - - # VCEK-specific CHIP_ID ↔ HWID equality check - if report.signer_info_parsed.signingKey == ReportSigner.VcekReportSigner: - if report.signer_info_parsed.maskChipKey and any(report.chip_id): - raise ValueError("maskChipKey is set but CHIP_ID is not zeroed") - if not report.signer_info_parsed.maskChipKey: - chain.validate_vcek_hwid(report.chip_id) - - # Platform info check - if options.platform_info is not None: - _validate_platform_info(report.platform_info_parsed, options.platform_info) - - # VMPL check - if options.vmpl is not None: # Must be between 0 and 3 and equal to the expected value - if not (0 <= report.vmpl <= 3): - raise ValueError(f"VMPL {report.vmpl} is not in valid range 0-3") - if report.vmpl != options.vmpl: - raise ValueError(f"VMPL mismatch: got {report.vmpl}, expected {options.vmpl}") - - # Provisional firmware check - we only support permit_provisional_firmware = False - if options.permit_provisional_firmware: - # Not supported - reject any request for provisional firmware - raise ValueError("Provisional firmware is not supported") - - # When permit_provisional_firmware = False, committed and current values must be equal - if report.committed_build != report.current_build: - raise ValueError(f"Committed build {report.committed_build} does not match current build {report.current_build}") - if report.committed_minor != report.current_minor: - raise ValueError(f"Committed minor version {report.committed_minor} does not match current minor version {report.current_minor}") - if report.committed_major != report.current_major: - raise ValueError(f"Committed major version {report.committed_major} does not match current major version {report.current_major}") - if report.committed_tcb != report.current_tcb: - raise ValueError(f"Committed TCB 0x{report.committed_tcb:x} does not match current TCB 0x{report.current_tcb:x}") - - # 
ID-block / author key requirements - if options.require_author_key or options.require_id_block: - # Not supported yet - raise ValueError("ID-block and author key requirements are not supported yet") - - -def _validate_policy(report_policy: SnpPolicy, required: SnpPolicy): - """ - Validate policy with security-aware checks. - - Logic follows Go reference implementation: - - Check ABI version compatibility - - Reject unauthorized capabilities (report has them, required doesn't allow) - - Reject missing required restrictions/features - - Raises ValueError if validation fails. - """ - # ABI version check - required version must not be greater than report version - if _compare_policy_versions(required, report_policy) > 0: - raise ValueError(f"Required ABI version ({required.abi_major}.{required.abi_minor}) is greater than report's ABI version ({report_policy.abi_major}.{report_policy.abi_minor})") - - # Unauthorized capabilities (report has them enabled, but required doesn't allow) - if not required.migrate_ma and report_policy.migrate_ma: - raise ValueError(f"Found unauthorized migration agent capability. Report policy: {report_policy}, Required policy: {required}") - - if not required.debug and report_policy.debug: - raise ValueError(f"Found unauthorized debug capability. Report policy: {report_policy}, Required policy: {required}") - - if not required.smt and report_policy.smt: - raise ValueError(f"Found unauthorized symmetric multithreading (SMT) capability. Report policy: {report_policy}, Required policy: {required}") - - if not required.cxl_allowed and report_policy.cxl_allowed: - raise ValueError(f"Found unauthorized CXL capability. Report policy: {report_policy}, Required policy: {required}") - - if not required.mem_aes256_xts and report_policy.mem_aes256_xts: - raise ValueError(f"Found unauthorized memory encryption mode. Report policy: {report_policy}, Required policy: {required}") - - # Required restrictions/features (report lacks what required mandates) - if required.single_socket and not report_policy.single_socket: - raise ValueError(f"Required single socket restriction not present. Report policy: {report_policy}, Required policy: {required}") - - if required.mem_aes256_xts and not report_policy.mem_aes256_xts: - raise ValueError(f"Found unauthorized memory encryption mode. Report policy: {report_policy}, Required policy: {required}") - - if required.rapl_dis and not report_policy.rapl_dis: - raise ValueError(f"Found unauthorized RAPL capability. Report policy: {report_policy}, Required policy: {required}") - - if required.ciphertext_hiding_dram and not report_policy.ciphertext_hiding_dram: - raise ValueError(f"Ciphertext hiding in DRAM isn't enforced. Report policy: {report_policy}, Required policy: {required}") - - if required.page_swap_disabled and not report_policy.page_swap_disabled: - raise ValueError(f"Page swap isn't disabled. Report policy: {report_policy}, Required policy: {required}") - -def _compare_policy_versions(required: SnpPolicy, report: SnpPolicy) -> int: - """ - Compare policy ABI versions. 
- Returns: - > 0 if required version is greater than report version - = 0 if versions are equal - < 0 if required version is less than report version - """ - # Compare major version first - if required.abi_major != report.abi_major: - return required.abi_major - report.abi_major - - # If major versions are equal, compare minor versions - return required.abi_minor - report.abi_minor - - -def _validate_platform_info(report_info: SnpPlatformInfo, required: SnpPlatformInfo): - """ - Validate platform info with security-aware checks. - - Logic follows Go reference implementation: - - If report has a feature enabled that required doesn't allow -> FAIL - - If report lacks a feature that required mandates -> FAIL - - Raises ValueError if validation fails. - """ - # Unauthorized features (report has it enabled, but required doesn't allow it) - if report_info.smt_enabled and not required.smt_enabled: - raise ValueError(f"Unauthorized platform feature SMT enabled. Report platform info: {report_info}, Required platform info: {required}") - - # Required features (report lacks something that required mandates) - if not report_info.ecc_enabled and required.ecc_enabled: - raise ValueError(f"Required platform feature ECC not enabled. Report platform info: {report_info}, Required platform info: {required}") - - if not report_info.tsme_enabled and required.tsme_enabled: - raise ValueError(f"Required platform feature TSME not enabled. Report platform info: {report_info}, Required platform info: {required}") - - if not report_info.rapl_disabled and required.rapl_disabled: - raise ValueError(f"Required platform feature RAPL not disabled. Report platform info: {report_info}, Required platform info: {required}") - - if not report_info.ciphertext_hiding_dram_enabled and required.ciphertext_hiding_dram_enabled: - raise ValueError(f"Required ciphertext hiding in DRAM not enforced. Report platform info: {report_info}, Required platform info: {required}") - - if not report_info.alias_check_complete and required.alias_check_complete: - raise ValueError(f"Required memory alias check hasn't been completed. Report platform info: {report_info}, Required platform info: {required}") - - if not report_info.tio_enabled and required.tio_enabled: - raise ValueError(f"Required TIO not enabled. 
Report platform info: {report_info}, Required platform info: {required}") \ No newline at end of file diff --git a/src/tinfoil/attestation/verify.py b/src/tinfoil/attestation/verify.py deleted file mode 100644 index 96fdf01..0000000 --- a/src/tinfoil/attestation/verify.py +++ /dev/null @@ -1,478 +0,0 @@ -#!/usr/bin/env python3 -""" -Simplified AMD SEV-SNP Attestation Verifier (VCEK Chain Only) -""" - -import os -from dataclasses import dataclass -from typing import Dict, TypeAlias -import binascii -import requests -from OpenSSL import crypto -import platformdirs - -from .abi_sevsnp import (Report, ReportSigner, TCBParts) -from .genoa_cert_chain import (ARK_CERT, ASK_CERT) - -from cryptography import x509 -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import ec, utils -from cryptography.x509.oid import ObjectIdentifier -import warnings -from cryptography.utils import CryptographyDeprecationWarning - -# Type alias for certificate extensions -Extensions: TypeAlias = Dict[ObjectIdentifier, bytes] - -# VCEK cache directory setup (can stay at module level) -_VCEK_CACHE_DIR = platformdirs.user_cache_dir("tinfoil", "tinfoil") -os.makedirs(_VCEK_CACHE_DIR, exist_ok=True) - -class SnpOid: - """OID extensions for the VCEK, used to verify attestation report""" - BOOTLOADER = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.1") - TEE = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.2") - SNP = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.3") - STRUCT_VERSION = ObjectIdentifier("1.3.6.1.4.1.3704.1.1") - PRODUCT_NAME_1 = ObjectIdentifier("1.3.6.1.4.1.3704.1.2") - BL_SPL = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.1") - TEE_SPL = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.2") - SNP_SPL = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.3") - SPL4 = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.4") - SPL5 = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.5") - SPL6 = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.6") - SPL7 = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.7") - UCODE = ObjectIdentifier("1.3.6.1.4.1.3704.1.3.8") - HWID = ObjectIdentifier("1.3.6.1.4.1.3704.1.4") - CSP_ID = ObjectIdentifier("1.3.6.1.4.1.3704.1.5") - -class CertificateChain: - """Represents the SEV certificate chain (ARK > ASK > VCEK)""" - ark: x509.Certificate - ask: x509.Certificate - vcek: x509.Certificate - - def __init__(self, ark: x509.Certificate, ask: x509.Certificate, vcek: x509.Certificate): - self.ark = ark - self.ask = ask - self.vcek = vcek - - @staticmethod - def _vcek_cache_path(product_name: str, chip_id: bytes, reported_tcb: int) -> str: - """ - Build a deterministic filename for a given (product, chip_id, tcb). - Uses the module-level _VCEK_CACHE_DIR. 
- """ - chip_hex = chip_id.hex() - tcb_hex = f"{reported_tcb:016x}" - filename = f"VCEK_{product_name}_{chip_hex}_{tcb_hex}.der" - return os.path.join(_VCEK_CACHE_DIR, filename) - - @classmethod - def from_files(cls, ark_path: str, ask_path: str, vcek_path: str) -> 'CertificateChain': - """Alternative constructor to load certificates from files""" - ark = cls._load_cert(ark_path) - ask = cls._load_cert(ask_path) - vcek = cls._load_cert(vcek_path) - return cls(ark=ark, ask=ask, vcek=vcek) - - @classmethod - def from_report(cls, report:Report) -> 'CertificateChain': - productName: str = report.productName - - if productName != "Genoa": - raise ValueError("This implementation only supports Genoa processors") - - # Use the hardcoded certificate chain - ark = x509.load_pem_x509_certificate(ARK_CERT) - ask = x509.load_pem_x509_certificate(ASK_CERT) - - signer_info = report.signer_info_parsed - - if signer_info.signingKey != ReportSigner.VcekReportSigner: - raise ValueError("This implementation only supports VCEK signed reports") - - # Fetch (or load) the VCEK certificate - vcek_url = _VCEKCertURL(productName, report.chip_id, report.reported_tcb) - cache_path = cls._vcek_cache_path(productName, report.chip_id, report.reported_tcb) - - # 1. Try the on‑disk cache - if os.path.isfile(cache_path): - with open(cache_path, "rb") as fh: - vcek_cert_data = fh.read() - else: - # 2. Cache miss → fetch from the KDS endpoint - try: - response = requests.get(vcek_url, timeout=10) - response.raise_for_status() - vcek_cert_data = response.content - # Persist to cache so the next call is instant - with open(cache_path, "wb") as fh: - fh.write(vcek_cert_data) - except requests.RequestException as e: - raise ValueError(f"Failed to fetch VCEK certificate: {e}") from e - - # Parse the (cached or freshly‑downloaded) certificate - try: - # cryptography 46+ emits a deprecation warning for non‑positive serial numbers. - # Suppress this specific deprecation warning locally when parsing VCEK DER. - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - message=r"Parsed a serial number which wasn't positive", - category=CryptographyDeprecationWarning, - ) - vcek = x509.load_der_x509_certificate(vcek_cert_data) - except Exception as e: - # Corrupted cache? Remove and propagate error so caller can retry. - if os.path.exists(cache_path): - try: - os.remove(cache_path) - except OSError: - pass - raise ValueError(f"Failed to parse VCEK certificate: {e}") from e - - # Return the complete certificate chain - return cls(ark=ark, ask=ask, vcek=vcek) - - @staticmethod - def _load_cert(filepath: str) -> x509.Certificate: - """Load an X.509 certificate from file""" - _, ext = os.path.splitext(filepath) - with open(filepath, 'rb') as f: - data = f.read() - - if ext.lower() == '.pem': - return x509.load_pem_x509_certificate(data) - else: - # Suppress cryptography deprecation warnings for DER parsing as above. 
- with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - message=r"Parsed a serial number which wasn't positive", - category=CryptographyDeprecationWarning, - ) - return x509.load_der_x509_certificate(data) - - def verify_chain(self) -> bool: - # Validate VCEK format - try: - self._validate_vcek_format() - except ValueError as e: - print(f"VCEK certificate validation failed: {e}") - return False - - # Validate ARK and ASK format - try: - self._validate_ark_format() - except ValueError as e: - print(f"ARK certificate validation failed: {e}") - return False - try: - self._validate_ask_format() - except ValueError as e: - print(f"ASK certificate validation failed: {e}") - return False - - # Verify the certificate chain using OpenSSL - try: - # Create a store and add the root (ARK) certificate - store = crypto.X509Store() - store.add_cert(crypto.X509.from_cryptography(self.ark)) - - # Add the intermediate (ASK) certificate - store.add_cert(crypto.X509.from_cryptography(self.ask)) - - # Create a store context - store_ctx = crypto.X509StoreContext(store, crypto.X509.from_cryptography(self.vcek)) - - # Verify the certificate - try: - store_ctx.verify_certificate() - return True - except crypto.X509StoreContextError as e: - print(f"Certificate chain verification failed: {e}") - return False - - except Exception as e: - print(f"Error during chain verification: {e}") - return False - - def _validate_ark_format(self): - if self.ark.version != x509.Version.v3: - raise ValueError("ARK certificate version is not 3") - if not _validateAmdLocation(self.ark.issuer): - raise ValueError("ARK certificate issuer is not a valid AMD location") - if not _validateAmdLocation(self.ark.subject): - raise ValueError("ARK certificate subject is not a valid AMD location") - - cn = self.ark.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value - if cn != "ARK-Genoa": - raise ValueError(f"ARK certificate subject common name is not ARK-Genoa but {cn}") - - # TODO add support for Certificate Revocation Lists - # NOTE Here the go implementation cross check Sev format with the X509 certificate but we only trust the certificate we ship with the code - - def _validate_ask_format(self): - # Validate ASK format - if self.ask.version != x509.Version.v3: - raise ValueError("ASK certificate version is not 3") - if not _validateAmdLocation(self.ask.issuer): - raise ValueError("ASK certificate issuer is not a valid AMD location") - if not _validateAmdLocation(self.ask.subject): - raise ValueError("ASK certificate subject is not a valid AMD location") - - cn = self.ask.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value - if cn != "SEV-Genoa": - raise ValueError(f"ASK certificate subject common name is not SEV-Genoa but {cn}") - - # TODO add support for Certificate Revocation Lists - # NOTE Here the go implementation cross check Sev format with the X509 certificate but we only trust the certificate we ship with the code - - def _validate_vcek_format(self): - """Validate the format of a VCEK certificate""" - - if self.vcek.version != x509.Version.v3: - raise ValueError(f"VCEK certificate version is not 3 but {self.vcek.version}") - - if self.vcek.signature_algorithm_oid != x509.SignatureAlgorithmOID.RSASSA_PSS: - raise ValueError(f"VCEK certificate signature algorithm is not RSASSA_PSS but {self.vcek.signature_algorithm_oid}") - - if self.vcek.public_key_algorithm_oid != x509.PublicKeyAlgorithmOID.EC_PUBLIC_KEY: - raise ValueError(f"VCEK certificate public key algorithm is not ECDSA but 
{self.vcek.public_key_algorithm_oid}") - - if self.vcek.public_key().curve.name != "secp384r1": - raise ValueError(f"VCEK certificate public key curve is not secp384r1 but {self.vcek.public_key().curve.name}") - - # Validate KDS Cert Subject - if not _validateAmdLocation(self.vcek.subject): - raise ValueError("VCEK certificate subject is not a valid AMD location") - - cn = self.vcek.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value - if cn != "SEV-VCEK": - raise ValueError(f"VCEK certificate subject common name is not SEV-VCEK but {cn}") - - # Get KDS and validate Cert Extensions - extensions = _get_certificate_extensions(self.vcek) - if SnpOid.CSP_ID in extensions: - raise ValueError(f"unexpected CSP_ID in VCEK certificate: {extensions[SnpOid.CSP_ID]}") - - if (SnpOid.HWID not in extensions) or len(extensions[SnpOid.HWID]) != 64: # ChipIDSize - raise ValueError(f"missing HWID extension for VCEK certificate") - - if extensions[SnpOid.PRODUCT_NAME_1] != b'\x16\x05Genoa': - raise ValueError(f"unexpected PRODUCT_NAME_1 in VCEK certificate: {extensions[SnpOid.PRODUCT_NAME_1]}") - - def validate_vcek_tcb(self, tcb: TCBParts): - """Validate the TCB extension in the VCEK certificate matches a given TCB""" - extensions = _get_certificate_extensions(self.vcek) - - if SnpOid.BL_SPL not in extensions: - raise ValueError(f"missing BL_SPL extension for VCEK certificate") - bl_spl = _decode_der_integer(extensions[SnpOid.BL_SPL]) - if bl_spl != tcb.bl_spl: - raise ValueError(f"BL_SPL extension in VCEK certificate does not match tcb.bl_spl: {bl_spl} != {tcb.bl_spl}") - - if SnpOid.TEE_SPL not in extensions: - raise ValueError(f"missing TEE_SPL extension for VCEK certificate") - tee_spl = _decode_der_integer(extensions[SnpOid.TEE_SPL]) - if tee_spl != tcb.tee_spl: - raise ValueError(f"TEE_SPL extension in VCEK certificate does not match tcb.tee_spl: {tee_spl} != {tcb.tee_spl}") - - if SnpOid.SNP_SPL not in extensions: - raise ValueError(f"missing SNP_SPL extension for VCEK certificate") - snp_spl = _decode_der_integer(extensions[SnpOid.SNP_SPL]) - if snp_spl != tcb.snp_spl: - raise ValueError(f"SNP_SPL extension in VCEK certificate does not match tcb.snp_spl: {snp_spl} != {tcb.snp_spl}") - - if SnpOid.UCODE not in extensions: - raise ValueError(f"missing UCODE extension for VCEK certificate") - ucode_spl = _decode_der_integer(extensions[SnpOid.UCODE]) - if ucode_spl != tcb.ucode_spl: - raise ValueError(f"UCODE extension in VCEK certificate does not match tcb.ucode_spl: {ucode_spl} != {tcb.ucode_spl}") - - def validate_vcek_hwid(self, chip_id: bytes): - """Validate the HWID extension in the VCEK certificate matches a given chip id""" - extensions = _get_certificate_extensions(self.vcek) - if SnpOid.HWID not in extensions: - raise ValueError(f"missing HWID extension for VCEK certificate") - if extensions[SnpOid.HWID] != chip_id: - raise ValueError(f"HWID extension in VCEK certificate does not match chip_id: {extensions[SnpOid.HWID]} != {chip_id}") - - -## HELPER FUNCTIONS - -def _get_certificate_extensions(cert: x509.Certificate) -> Extensions: - """Get the extensions from the VCEK certificate""" - extensions = {} - for ext in cert.extensions: - extensions[ext.oid] = ext.value.value - return extensions - -def _decode_der_integer(der_bytes: bytes) -> int: - """Decode a DER-encoded INTEGER""" - if len(der_bytes) < 2 or der_bytes[0] != 0x02: - raise ValueError(f"Invalid DER INTEGER: {der_bytes.hex()}") - - length = der_bytes[1] - if len(der_bytes) != 2 + length: - raise ValueError(f"Invalid DER 
INTEGER length: {der_bytes.hex()}") - - # Convert the integer bytes to int (big-endian) - return int.from_bytes(der_bytes[2:2+length], byteorder='big') - -def _validateAmdLocation(name: x509.Name) -> bool: - """Validate that the certificate subject name matches AMD's expected values. - - Args: - name: The x509.Name object to validate - - Returns: - bool: True if all fields match expected values, False otherwise - """ - def check_singleton_list(values: list[str], field_name: str, expected: str) -> bool: - if len(values) != 1: - print(f"Expected exactly one {field_name}, got {len(values)}") - return False - if values[0] != expected: - print(f"Unexpected {field_name} value: '{values[0]}', expected '{expected}'") - return False - return True - - # Get the name attributes - country = name.get_attributes_for_oid(x509.NameOID.COUNTRY_NAME) - locality = name.get_attributes_for_oid(x509.NameOID.LOCALITY_NAME) - state = name.get_attributes_for_oid(x509.NameOID.STATE_OR_PROVINCE_NAME) - org = name.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME) - org_unit = name.get_attributes_for_oid(x509.NameOID.ORGANIZATIONAL_UNIT_NAME) - - # Extract the values from the attributes - country_values = [attr.value for attr in country] - locality_values = [attr.value for attr in locality] - state_values = [attr.value for attr in state] - org_values = [attr.value for attr in org] - org_unit_values = [attr.value for attr in org_unit] - - # Validate each field - if not check_singleton_list(country_values, "country", "US"): - return False - if not check_singleton_list(locality_values, "locality", "Santa Clara"): - return False - if not check_singleton_list(state_values, "state", "CA"): - return False - if not check_singleton_list(org_values, "organization", "Advanced Micro Devices"): - return False - if not check_singleton_list(org_unit_values, "organizational unit", "Engineering"): - return False - - return True - -def _VCEKCertURL(productName: str, chip_id: bytes, reported_tcb: int) -> str: - # TODO add support for other product names - """Generate the VCEK certificate URL based on the product name, chip ID, and reported TCB""" - parts = TCBParts.from_int(reported_tcb) - base_url = "https://kds-proxy.tinfoil.sh/vcek/v1" - chip_id_hex = binascii.hexlify(chip_id).decode('ascii') - return f"{base_url}/{productName}/{chip_id_hex}?blSPL={parts.bl_spl}&teeSPL={parts.tee_spl}&snpSPL={parts.snp_spl}&ucodeSPL={parts.ucode_spl}" - -def _verify_report_signature(vcek: x509.Certificate, report: Report) -> bool: - """Verify the attestation report signature using VCEK's public key""" - - # Validate Report Format - POLICY_RESERVED_1_BIT = 17 - - if report.version < 2: - raise ValueError(f"Report version is lower than 2: is {report.version}") - - # Check reserved bit must be 1 - if not (report.policy & (1 << POLICY_RESERVED_1_BIT)): - raise ValueError(f"policy[{POLICY_RESERVED_1_BIT}] is reserved, must be 1, got 0") - - # Check bits 63-26 must be zero - if report.policy >> 26: - raise ValueError("policy bits 63-26 must be zero") - - try: - # Check signature algorithm - if report.signature_algo != 1: # 1 = SignEcdsaP384Sha384 - print(f"Unknown SignatureAlgo: {report.signature_algo}") - return False - - # Verify the public key is an EC key - public_key = vcek.public_key() - if not isinstance(public_key, ec.EllipticCurvePublicKey): - print("VCEK doesn't contain an EC public key") - return False - - # Convert the raw signature to DER format - # The signature in the report is in raw R||S format in AMD's little-endian format - # Each 
component is 72 bytes (0x48) for P384 - r_bytes = bytes(reversed(report.signature[0:0x48])) # Reverse bytes for big-endian - s_bytes = bytes(reversed(report.signature[0x48:0x90])) # Reverse bytes for big-endian - - r = int.from_bytes(r_bytes.lstrip(b'\x00'), byteorder='big') - s = int.from_bytes(s_bytes.lstrip(b'\x00'), byteorder='big') - - der_signature = utils.encode_dss_signature(r, s) - - # Verify signature - public_key.verify( - der_signature, - report.signed_data, - ec.ECDSA(hashes.SHA384()) - ) - return True - except Exception as e: - print(f"Attestation signature verification failed: {e}") - return False - -def verify_attestation(chain: CertificateChain, report: Report) -> bool: - """Verify attestation report with the certificate chain""" - try: - # Verify certificate chain - if not chain.verify_chain(): - # Since verify_chain() already prints its own error messages - return False - - # Verify report - if not _verify_report_signature(chain.vcek, report): - return False - - return True - - except Exception as e: - print(f"Verification failed: {e}") - return False - -def main(): - import argparse - - parser = argparse.ArgumentParser(description="Simplified AMD SEV-SNP Attestation Verifier") - parser.add_argument("--ark", required=True, help="Path to ARK certificate") - parser.add_argument("--ask", required=True, help="Path to ASK certificate") - parser.add_argument("--vcek", required=True, help="Path to VCEK certificate") - parser.add_argument("--report", required=True, help="Path to attestation report") - - args = parser.parse_args() - - # Load certificate chain - chain = CertificateChain.from_files(args.ark, args.ask, args.vcek) - - # Read and parse attestation report - with open(args.report, 'rb') as f: - report_data = f.read() - - report = Report(report_data) - - result = verify_attestation(chain, report) - if result: - print("Attestation verification successful") - return 0 - else: - print("Attestation verification failed") - return 1 - - -if __name__ == "__main__": - import sys - sys.exit(main()) \ No newline at end of file From 34ed960935e62325511d642470a27e2f7bb113f2 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 16:30:57 -0500 Subject: [PATCH 16/19] fix: assert RTMR count before zip comparison, fix stale comment --- src/tinfoil/attestation/collateral_tdx.py | 3 +-- src/tinfoil/attestation/validate_tdx.py | 5 +++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/tinfoil/attestation/collateral_tdx.py b/src/tinfoil/attestation/collateral_tdx.py index 6a7bb50..0965229 100644 --- a/src/tinfoil/attestation/collateral_tdx.py +++ b/src/tinfoil/attestation/collateral_tdx.py @@ -1362,8 +1362,7 @@ def calculate_min_tcb_evaluation_data_number( if item.tcb_recovery_event_date >= cutoff_date: return item.tcb_evaluation_data_number - # If all numbers are too old, return the highest (most recent) one - # This is a fallback - in practice, at least the most recent should be recent + # Fail closed if all numbers are too old highest = sorted_numbers[-1] raise CollateralError( f"All TCB evaluation data numbers are older than {max_age_days} days. 
" diff --git a/src/tinfoil/attestation/validate_tdx.py b/src/tinfoil/attestation/validate_tdx.py index ac4bb0b..442eb26 100644 --- a/src/tinfoil/attestation/validate_tdx.py +++ b/src/tinfoil/attestation/validate_tdx.py @@ -355,6 +355,11 @@ def _validate_exact_byte_matches(quote: QuoteV4, options: PolicyOptions) -> None # RTMR checks if t.rtmrs is not None: + if len(body.rtmrs) != len(t.rtmrs): + raise TdxValidationError( + f"RTMR count mismatch: quote has {len(body.rtmrs)}, " + f"policy expects {len(t.rtmrs)}" + ) for i, (given, expected) in enumerate(zip(body.rtmrs, t.rtmrs)): if expected is not None: _byte_check(f"RTMR[{i}]", given, expected) From d19cc31ac4958a736098bee3f0ed9b0d215df07a Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 17:11:15 -0500 Subject: [PATCH 17/19] fix: reject unrecognized OIDs and missing fields in PCK TCB extension parsing --- src/tinfoil/attestation/pck_extensions.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/tinfoil/attestation/pck_extensions.py b/src/tinfoil/attestation/pck_extensions.py index acf86c1..749fd5d 100644 --- a/src/tinfoil/attestation/pck_extensions.py +++ b/src/tinfoil/attestation/pck_extensions.py @@ -239,6 +239,9 @@ def _parse_tcb(value_component) -> PckCertTCB: tcb_components = [0] * TCB_COMPONENTS_COUNT pce_svn = 0 cpu_svn = bytes(CPU_SVN_SIZE) + found_pce_svn = False + found_cpu_svn = False + found_components: set[int] = set() with _asn1_errors("TCB extension"): for item in tcb_seq: @@ -255,6 +258,7 @@ def _parse_tcb(value_component) -> PckCertTCB: f"PCE SVN value {val} out of uint16 range" ) pce_svn = val + found_pce_svn = True elif oid_str == OID_CPU_SVN.dotted_string: raw = bytes(item[1]) if len(raw) != CPU_SVN_SIZE: @@ -262,6 +266,7 @@ def _parse_tcb(value_component) -> PckCertTCB: f"CPU SVN has wrong size: expected {CPU_SVN_SIZE}, got {len(raw)}" ) cpu_svn = raw + found_cpu_svn = True elif (idx := _TCB_COMPONENT_OID_INDEX.get(oid_str)) is not None: val = int(item[1]) if val < 0 or val > 0xFF: @@ -269,5 +274,20 @@ def _parse_tcb(value_component) -> PckCertTCB: f"TCB component {idx + 1} value {val} out of byte range" ) tcb_components[idx] = val + found_components.add(idx) + else: + raise PckExtensionError( + f"Unrecognized OID in TCB extension: {oid_str}" + ) + + if not found_pce_svn: + raise PckExtensionError("PCE SVN not found in TCB extension") + if not found_cpu_svn: + raise PckExtensionError("CPU SVN not found in TCB extension") + if len(found_components) != TCB_COMPONENTS_COUNT: + missing = set(range(TCB_COMPONENTS_COUNT)) - found_components + raise PckExtensionError( + f"Missing TCB components: {sorted(i + 1 for i in missing)}" + ) return PckCertTCB(pce_svn=pce_svn, cpu_svn=cpu_svn, tcb_components=tcb_components) From da9274e50a70dbad063a96e07f6d847e2feb8bf3 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 17:35:17 -0500 Subject: [PATCH 18/19] refactor: move safe_gzip_decompress from types.py to utils.py --- src/tinfoil/attestation/attestation_sev.py | 3 +- src/tinfoil/attestation/attestation_tdx.py | 3 +- src/tinfoil/attestation/types.py | 31 ------------------- src/tinfoil/attestation/utils.py | 36 ++++++++++++++++++++++ 4 files changed, 40 insertions(+), 33 deletions(-) create mode 100644 src/tinfoil/attestation/utils.py diff --git a/src/tinfoil/attestation/attestation_sev.py b/src/tinfoil/attestation/attestation_sev.py index 932b0ed..30bb6a5 100644 --- a/src/tinfoil/attestation/attestation_sev.py +++ b/src/tinfoil/attestation/attestation_sev.py @@ -9,7 +9,8 @@ 
from typing import Optional from .abi_sev import TCBParts, SnpPolicy, SnpPlatformInfo, Report -from .types import Measurement, Verification, PredicateType, TLS_KEY_FP_SIZE, HPKE_KEY_SIZE, safe_gzip_decompress +from .types import Measurement, Verification, PredicateType, TLS_KEY_FP_SIZE, HPKE_KEY_SIZE +from .utils import safe_gzip_decompress from .verify_sev import verify_attestation, CertificateChain from .validate_sev import validate_report, ValidationOptions diff --git a/src/tinfoil/attestation/attestation_tdx.py b/src/tinfoil/attestation/attestation_tdx.py index 33817a7..2efd41e 100644 --- a/src/tinfoil/attestation/attestation_tdx.py +++ b/src/tinfoil/attestation/attestation_tdx.py @@ -29,7 +29,8 @@ MR_OWNER_SIZE, MR_OWNER_CONFIG_SIZE, ) -from .types import TLS_KEY_FP_SIZE, HPKE_KEY_SIZE, Measurement, Verification, PredicateType, HardwareMeasurement, HardwareMeasurementError, TDX_MRTD_IDX, TDX_RTMR0_IDX, TDX_REGISTER_COUNT, safe_gzip_decompress +from .types import TLS_KEY_FP_SIZE, HPKE_KEY_SIZE, Measurement, Verification, PredicateType, HardwareMeasurement, HardwareMeasurementError, TDX_MRTD_IDX, TDX_RTMR0_IDX, TDX_REGISTER_COUNT +from .utils import safe_gzip_decompress from .verify_tdx import ( verify_tdx_quote, TdxVerificationError, diff --git a/src/tinfoil/attestation/types.py b/src/tinfoil/attestation/types.py index f8a4186..bce4704 100644 --- a/src/tinfoil/attestation/types.py +++ b/src/tinfoil/attestation/types.py @@ -35,37 +35,6 @@ MULTIPLATFORM_RTMR1_IDX = 1 MULTIPLATFORM_RTMR2_IDX = 2 -# Shared decompression constants -MAX_DECOMPRESSED_SIZE = 10 * 1024 * 1024 # 10 MiB - - -def safe_gzip_decompress(data: bytes, max_size: int = MAX_DECOMPRESSED_SIZE) -> bytes: - """Decompress gzip data with a size limit to prevent gzip bombs. - - Args: - data: Gzip-compressed bytes - max_size: Maximum allowed decompressed size - - Returns: - Decompressed bytes - - Raises: - ValueError: If decompressed data exceeds max_size or decompression fails - """ - import gzip - import io - - try: - with gzip.GzipFile(fileobj=io.BytesIO(data)) as f: - result = f.read(max_size + 1) - except (OSError, EOFError) as e: - raise ValueError(f"Gzip decompression failed: {e}") from e - if len(result) > max_size: - raise ValueError( - f"Decompressed attestation exceeds maximum size ({max_size} bytes)" - ) - return result - # ============================================================================= # Predicate types diff --git a/src/tinfoil/attestation/utils.py b/src/tinfoil/attestation/utils.py new file mode 100644 index 0000000..ccc4ef3 --- /dev/null +++ b/src/tinfoil/attestation/utils.py @@ -0,0 +1,36 @@ +""" +Shared utility functions for attestation modules. + +This module has no intra-package dependencies, so any module +can import from it without risk of circular imports. +""" + +import gzip +import io + +MAX_DECOMPRESSED_SIZE = 10 * 1024 * 1024 # 10 MiB + + +def safe_gzip_decompress(data: bytes, max_size: int = MAX_DECOMPRESSED_SIZE) -> bytes: + """Decompress gzip data with a size limit to prevent gzip bombs. 
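+
+    Why the cap matters (illustrative figures, not part of the spec): gzip
+    can expand roughly 1000:1 on degenerate input, so a few kilobytes of
+    attacker-supplied compressed data could otherwise balloon into
+    gigabytes. A sketch of the failure mode:
+
+        >>> import gzip
+        >>> bomb = gzip.compress(b"\x00" * (20 * 1024 * 1024))  # ~20 KiB compressed
+        >>> safe_gzip_decompress(bomb)
+        Traceback (most recent call last):
+            ...
+        ValueError: Decompressed attestation exceeds maximum size (10485760 bytes)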
+ + Args: + data: Gzip-compressed bytes + max_size: Maximum allowed decompressed size + + Returns: + Decompressed bytes + + Raises: + ValueError: If decompressed data exceeds max_size or decompression fails + """ + try: + with gzip.GzipFile(fileobj=io.BytesIO(data)) as f: + result = f.read(max_size + 1) + except (OSError, EOFError) as e: + raise ValueError(f"Gzip decompression failed: {e}") from e + if len(result) > max_size: + raise ValueError( + f"Decompressed attestation exceeds maximum size ({max_size} bytes)" + ) + return result From 90198df1b3dea920814834618cf9ebe8a7bdad01 Mon Sep 17 00:00:00 2001 From: Jules Drean Date: Mon, 23 Feb 2026 17:55:20 -0500 Subject: [PATCH 19/19] refactor: deduplicate DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER constant --- src/tinfoil/attestation/attestation_tdx.py | 15 +-------------- src/tinfoil/attestation/collateral_tdx.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/tinfoil/attestation/attestation_tdx.py b/src/tinfoil/attestation/attestation_tdx.py index 2efd41e..e99f75a 100644 --- a/src/tinfoil/attestation/attestation_tdx.py +++ b/src/tinfoil/attestation/attestation_tdx.py @@ -49,6 +49,7 @@ CollateralError, TdxCollateral, TcbLevel, + DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER, ) @@ -56,20 +57,6 @@ # Orchestration Constants # ============================================================================= -# Minimum required tcbEvaluationDataNumber. -# This prevents accepting collateral issued before critical security updates. -# See: https://www.intel.com/content/www/us/en/developer/topic-technology/software-security-guidance/trusted-computing-base-recovery-attestation.html -# -# This value should be set using: -# from tinfoil.attestation.collateral_tdx import calculate_min_tcb_evaluation_data_number -# min_num = calculate_min_tcb_evaluation_data_number() -# -# The function queries Intel PCS and returns the lowest tcbEvaluationDataNumber -# whose TCB recovery event date is within the last year. -# -# Current value 18 corresponds to TCB recovery event date 2024-11-12. -DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER = 18 - # Expected values for policy validation # TdAttributes: All zeros except SEPT_VE_DISABLE=1 EXPECTED_TD_ATTRIBUTES = bytes.fromhex("0000001000000000") diff --git a/src/tinfoil/attestation/collateral_tdx.py b/src/tinfoil/attestation/collateral_tdx.py index 0965229..21a9a6b 100644 --- a/src/tinfoil/attestation/collateral_tdx.py +++ b/src/tinfoil/attestation/collateral_tdx.py @@ -365,6 +365,20 @@ def _is_crl_fresh(crl: x509.CertificateRevocationList) -> bool: # Intel-defined field sizes (hex character counts) # ============================================================================= +# Minimum required tcbEvaluationDataNumber. +# This prevents accepting collateral issued before critical security updates. +# See: https://www.intel.com/content/www/us/en/developer/topic-technology/software-security-guidance/trusted-computing-base-recovery-attestation.html +# +# This value should ideally be set dynamically using: +# from tinfoil.attestation.collateral_tdx import calculate_min_tcb_evaluation_data_number +# min_num = calculate_min_tcb_evaluation_data_number() +# +# That function queries Intel PCS and returns the lowest tcbEvaluationDataNumber +# whose TCB recovery event date is within the last year. +# +# Current value 18 corresponds to TCB recovery event date 2024-11-12. 
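+#
+# A minimal sketch of that dynamic approach (illustrative only; assumes
+# Intel PCS is reachable and that CollateralError is the failure mode):
+#
+#     try:
+#         min_num = calculate_min_tcb_evaluation_data_number()
+#     except CollateralError:
+#         # Fall back to the pinned default below when PCS is unavailable.
+#         min_num = DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER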
+DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER = 18 + TDX_MRSIGNER_SIZE = 48 # TDX module MRSIGNER: 48 bytes QE_MRSIGNER_SIZE = 32 # QE enclave MRSIGNER: 32 bytes QE_ATTRIBUTES_SIZE = 16 # QE enclave attributes: 16 bytes @@ -2010,7 +2024,7 @@ class CollateralValidationResult: def validate_collateral( quote: "QuoteV4", pck_chain: "PCKCertificateChain", - min_tcb_evaluation_data_number: int = 18, + min_tcb_evaluation_data_number: int = DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER, ) -> CollateralValidationResult: """ Validate all collateral for a TDX quote.