diff --git a/.gitignore b/.gitignore index 2f8fa1f..0c558fb 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,10 @@ cython_debug/ tuf-repo-cdn.sigstore.dev.json verifier/ tinfoil/tinfoil_verifier/ -.DS_Store \ No newline at end of file +.DS_Store + +# Logs +*.log + +# Generated documentation +docs/_build/ diff --git a/pyproject.toml b/pyproject.toml index 0682a33..b1e6fa0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "requests>=2.31.0", "cryptography>=42.0.0", "pyOpenSSL>=25.0.0", + "pyasn1>=0.4.0", "sigstore>=4.1.0", "platformdirs>=4.2.0", "pytest-asyncio>=0.26.0" diff --git a/src/tinfoil/attestation/__init__.py b/src/tinfoil/attestation/__init__.py index 5697fd3..7e18196 100644 --- a/src/tinfoil/attestation/__init__.py +++ b/src/tinfoil/attestation/__init__.py @@ -1,17 +1,39 @@ +from .types import ( + PredicateType, + TDX_TYPES, + Measurement, + Verification, + HardwareMeasurement, + AttestationError, + FormatMismatchError, + MeasurementMismatchError, + Rtmr3NotZeroError, + HardwareMeasurementError, + RTMR3_ZERO, +) from .attestation import ( fetch_attestation, verify_attestation_json, - verify_sev_attestation_v2, - Measurement, - PredicateType, - from_snp_digest ) +from .attestation_tdx import verify_tdx_attestation_v2, TdxAttestationError, verify_tdx_hardware +from .attestation_sev import verify_sev_attestation_v2, SevAttestationError __all__ = [ 'fetch_attestation', 'verify_sev_attestation_v2', + 'verify_tdx_attestation_v2', + 'verify_tdx_hardware', 'verify_attestation_json', 'Measurement', + 'Verification', 'PredicateType', - 'from_snp_digest' -] \ No newline at end of file + 'RTMR3_ZERO', + 'AttestationError', + 'FormatMismatchError', + 'MeasurementMismatchError', + 'Rtmr3NotZeroError', + 'HardwareMeasurementError', + 'HardwareMeasurement', + 'TdxAttestationError', + 'SevAttestationError', +] diff --git a/src/tinfoil/attestation/abi_sevsnp.py b/src/tinfoil/attestation/abi_sev.py similarity index 100% rename from src/tinfoil/attestation/abi_sevsnp.py rename to src/tinfoil/attestation/abi_sev.py diff --git a/src/tinfoil/attestation/abi_tdx.py b/src/tinfoil/attestation/abi_tdx.py new file mode 100644 index 0000000..e91482e --- /dev/null +++ b/src/tinfoil/attestation/abi_tdx.py @@ -0,0 +1,758 @@ +""" +TDX Quote parsing structures and constants. + +This module provides data structures and parsing logic for Intel TDX attestation +quotes in the QuoteV4 format. +""" + +import struct +from dataclasses import dataclass +from typing import List + +# ============================================================================= +# Constants +# ============================================================================= + +# Quote structure sizes +QUOTE_MIN_SIZE = 0x3FC # 1020 bytes minimum +HEADER_SIZE = 0x30 # 48 bytes +TD_QUOTE_BODY_SIZE = 0x248 # 584 bytes +QE_REPORT_SIZE = 0x180 # 384 bytes + +# Quote versions +QUOTE_VERSION_V4 = 4 +QUOTE_VERSION_V5 = 5 + +# TEE type for TDX +TEE_TDX = 0x00000081 + +# Attestation key type (ECDSA-256-with-P-256 curve) +ATTESTATION_KEY_TYPE_ECDSA_P256 = 2 + +# Certification data types +CERT_DATA_TYPE_PCK_CERT_CHAIN = 5 +CERT_DATA_TYPE_QE_REPORT = 6 + +# Field sizes +TEE_TCB_SVN_SIZE = 0x10 # 16 bytes +MR_SEAM_SIZE = 0x30 # 48 bytes +MR_SIGNER_SEAM_SIZE = 0x30 # 48 bytes +SEAM_ATTRIBUTES_SIZE = 0x08 # 8 bytes +TD_ATTRIBUTES_SIZE = 0x08 # 8 bytes +XFAM_SIZE = 0x08 # 8 bytes +MR_TD_SIZE = 0x30 # 48 bytes +MR_CONFIG_ID_SIZE = 0x30 # 48 bytes +MR_OWNER_SIZE = 0x30 # 48 bytes +MR_OWNER_CONFIG_SIZE = 0x30 # 48 bytes +RTMR_SIZE = 0x30 # 48 bytes +RTMR_COUNT = 4 +REPORT_DATA_SIZE = 0x40 # 64 bytes +QE_VENDOR_ID_SIZE = 0x10 # 16 bytes +USER_DATA_SIZE = 0x14 # 20 bytes +SIGNATURE_SIZE = 0x40 # 64 bytes +ATTESTATION_KEY_SIZE = 0x40 # 64 bytes +ECDSA_P256_COMPONENT_SIZE = 0x20 # 32 bytes per R or S component +SHA256_HASH_SIZE = 0x20 # 32 bytes +CERT_DATA_HEADER_SIZE = 6 # 2 bytes type + 4 bytes size +PCK_CERT_CHAIN_COUNT = 3 # leaf, intermediate, root + +# Intel QE Vendor ID: 939a7233-f79c-4ca9-940a-0db3957f0607 +INTEL_QE_VENDOR_ID = bytes.fromhex("939a7233f79c4ca9940a0db3957f0607") + +# ============================================================================= +# Header offsets (relative to quote start) +# ============================================================================= + +HEADER_VERSION_START = 0x00 +HEADER_VERSION_END = 0x02 +HEADER_AK_TYPE_START = 0x02 +HEADER_AK_TYPE_END = 0x04 +HEADER_TEE_TYPE_START = 0x04 +HEADER_TEE_TYPE_END = 0x08 +# Bytes 0x08-0x0C are reserved in QuoteV4. +# Note: Some older specs labeled these as QE_SVN/PCE_SVN, but they are +# always zero in practice. The actual SVN values come from: +# - PCE SVN: PCK certificate extensions (OID 1.2.840.113741.1.13.1.2.17) +# - QE ISV SVN: QE Report at offset 0x102 within certification data +HEADER_RESERVED1_START = 0x08 +HEADER_RESERVED1_END = 0x0C +HEADER_QE_VENDOR_ID_START = 0x0C +HEADER_QE_VENDOR_ID_END = 0x1C +HEADER_USER_DATA_START = 0x1C +HEADER_USER_DATA_END = 0x30 + +# ============================================================================= +# TdQuoteBody offsets (relative to body start at 0x30) +# ============================================================================= + +TD_TEE_TCB_SVN_START = 0x00 +TD_TEE_TCB_SVN_END = 0x10 +TD_MR_SEAM_START = 0x10 +TD_MR_SEAM_END = 0x40 +TD_MR_SIGNER_SEAM_START = 0x40 +TD_MR_SIGNER_SEAM_END = 0x70 +TD_SEAM_ATTRIBUTES_START = 0x70 +TD_SEAM_ATTRIBUTES_END = 0x78 +TD_ATTRIBUTES_START = 0x78 +TD_ATTRIBUTES_END = 0x80 +TD_XFAM_START = 0x80 +TD_XFAM_END = 0x88 +TD_MR_TD_START = 0x88 +TD_MR_TD_END = 0xB8 +TD_MR_CONFIG_ID_START = 0xB8 +TD_MR_CONFIG_ID_END = 0xE8 +TD_MR_OWNER_START = 0xE8 +TD_MR_OWNER_END = 0x118 +TD_MR_OWNER_CONFIG_START = 0x118 +TD_MR_OWNER_CONFIG_END = 0x148 +TD_RTMRS_START = 0x148 +TD_RTMRS_END = 0x208 +TD_REPORT_DATA_START = 0x208 +TD_REPORT_DATA_END = 0x248 + +# ============================================================================= +# Quote-level offsets +# ============================================================================= + +QUOTE_HEADER_START = 0x00 +QUOTE_HEADER_END = 0x30 +QUOTE_BODY_START = 0x30 +QUOTE_BODY_END = 0x278 +QUOTE_SIGNED_DATA_SIZE_START = 0x278 +QUOTE_SIGNED_DATA_SIZE_END = 0x27C +QUOTE_SIGNED_DATA_START = 0x27C + +# ============================================================================= +# SignedData offsets (relative to signed data start) +# ============================================================================= + +SIGNED_DATA_SIGNATURE_START = 0x00 +SIGNED_DATA_SIGNATURE_END = 0x40 +SIGNED_DATA_AK_START = 0x40 +SIGNED_DATA_AK_END = 0x80 +SIGNED_DATA_CERT_DATA_START = 0x80 + +# ============================================================================= +# QE Report offsets (within certification data) +# ============================================================================= + +QE_CPU_SVN_START = 0x00 +QE_CPU_SVN_END = 0x10 +QE_MISC_SELECT_START = 0x10 +QE_MISC_SELECT_END = 0x14 +QE_ATTRIBUTES_START = 0x30 +QE_ATTRIBUTES_END = 0x40 +QE_MR_ENCLAVE_START = 0x40 +QE_MR_ENCLAVE_END = 0x60 +QE_MR_SIGNER_START = 0x80 +QE_MR_SIGNER_END = 0xA0 +QE_ISV_PROD_ID_START = 0x100 +QE_ISV_PROD_ID_END = 0x102 +QE_ISV_SVN_START = 0x102 +QE_ISV_SVN_END = 0x104 +QE_REPORT_DATA_START = 0x140 +QE_REPORT_DATA_END = 0x180 + + +# ============================================================================= +# Data Structures +# ============================================================================= + +@dataclass +class TdxHeader: + """ + TDX Quote header (48 bytes). + + Contains quote metadata including version, attestation key type, + TEE type, and vendor information. + + Note: Bytes 8-11 are reserved. Some older specs labeled these as + QE_SVN/PCE_SVN, but they are always zero in practice. The actual + SVN values come from PCK certificate extensions and QE Report. + """ + version: int # 2 bytes - must be 4 for QuoteV4 + attestation_key_type: int # 2 bytes - must be 2 (ECDSA-P256) + tee_type: int # 4 bytes - must be 0x81 (TDX) + reserved: bytes # 4 bytes - reserved (was QE_SVN/PCE_SVN in older specs) + qe_vendor_id: bytes # 16 bytes - Intel: 939a7233-f79c-4ca9-940a-0db3957f0607 + user_data: bytes # 20 bytes - Custom data from QE + + def __str__(self) -> str: + return ( + f"TdxHeader(version={self.version}, " + f"ak_type={self.attestation_key_type}, " + f"tee_type=0x{self.tee_type:x}, " + f"qe_vendor_id={self.qe_vendor_id.hex()})" + ) + + +@dataclass +class TdQuoteBody: + """ + TD Quote Body (584 bytes). + + Contains the TD's measurements and report data. This is the core + attestation data signed by the QE. + """ + tee_tcb_svn: bytes # 16 bytes - TEE TCB Security Version Number + mr_seam: bytes # 48 bytes - Measurement of SEAM module + mr_signer_seam: bytes # 48 bytes - Signer of SEAM module (zeros for Intel SEAM) + seam_attributes: bytes # 8 bytes - SEAM attributes + td_attributes: bytes # 8 bytes - TD attributes + xfam: bytes # 8 bytes - Extended feature mask + mr_td: bytes # 48 bytes - Measurement of TD (MRTD) + mr_config_id: bytes # 48 bytes - Config ID + mr_owner: bytes # 48 bytes - Owner measurement + mr_owner_config: bytes # 48 bytes - Owner config measurement + rtmrs: List[bytes] # 4 x 48 bytes - Runtime measurement registers + report_data: bytes # 64 bytes - Custom data (TLS key FP + HPKE key) + + def __str__(self) -> str: + if len(self.rtmrs) == RTMR_COUNT: + rtmr_lines = "".join( + f" rtmr{i}={self.rtmrs[i].hex()},\n" for i in range(RTMR_COUNT) + ) + else: + rtmr_lines = f" rtmrs=({len(self.rtmrs)} entries),\n" + return ( + f"TdQuoteBody(\n" + f" mr_td={self.mr_td.hex()},\n" + f"{rtmr_lines}" + f" mr_seam={self.mr_seam.hex()},\n" + f" report_data={self.report_data.hex()}\n" + f")" + ) + + def get_measurements(self) -> List[bytes]: + """Return the 5 TDX measurements: [MRTD, RTMR0, RTMR1, RTMR2, RTMR3].""" + return [self.mr_td] + self.rtmrs + + +@dataclass +class QeReport: + """ + Quoting Enclave Report (384 bytes). + + SGX enclave report from the Quoting Enclave, used to verify + the attestation key is bound to a legitimate QE. + """ + cpu_svn: bytes # 16 bytes + misc_select: int # 4 bytes + attributes: bytes # 16 bytes + mr_enclave: bytes # 32 bytes - QE enclave measurement + mr_signer: bytes # 32 bytes - QE signer measurement + isv_prod_id: int # 2 bytes - Product ID + isv_svn: int # 2 bytes - Security Version Number + report_data: bytes # 64 bytes - Contains hash of attestation key + + def __str__(self) -> str: + return ( + f"QeReport(mr_enclave={self.mr_enclave.hex()}, " + f"mr_signer={self.mr_signer.hex()}, " + f"isv_prod_id={self.isv_prod_id}, " + f"isv_svn={self.isv_svn})" + ) + + +@dataclass +class QeReportCertificationData: + """ + QE Report Certification Data. + + Contains the QE report, its signature, authentication data, + and the nested PCK certificate chain. + """ + qe_report: bytes # 384 bytes - Raw QE report + qe_report_parsed: QeReport # Parsed QE report + qe_report_signature: bytes # 64 bytes - ECDSA signature over QE report + qe_auth_data: bytes # Variable - Authentication data + pck_cert_chain_data: "PckCertChainData" # Nested PCK certificate chain + + +@dataclass +class PckCertChainData: + """ + PCK Certificate Chain Data. + + Contains the certification data type and the actual certificate + chain in PEM format. + """ + cert_type: int # 2 bytes - Should be 5 (PCK cert chain) + cert_data_size: int # 4 bytes + cert_data: bytes # Variable - PEM certificate chain + + +@dataclass +class CertificationData: + """ + Certification Data from the quote (type 6 only). + + Contains QE report certification data with nested PCK cert chain. + Type 5 (direct PCK cert chain) is not supported as it lacks the + QE report needed for attestation key binding verification. + """ + cert_type: int # 2 bytes - Must be 6 (QE report certification data) + cert_data_size: int # 4 bytes - Size of certification data + qe_report_data: QeReportCertificationData + + def get_pck_chain(self) -> PckCertChainData: + """Get the PCK certificate chain from the QE report certification data.""" + return self.qe_report_data.pck_cert_chain_data + + +@dataclass +class SignedData: + """ + Signed Data section of the quote. + + Contains the quote signature, attestation public key, and + certification data (certificate chain). + """ + signature: bytes # 64 bytes - ECDSA signature (R || S) + attestation_key: bytes # 64 bytes - Raw ECDSA P-256 public key + certification_data: CertificationData + + def __str__(self) -> str: + return ( + f"SignedData(signature={self.signature[:8].hex()}..., " + f"attestation_key={self.attestation_key[:8].hex()}..., " + f"cert_type={self.certification_data.cert_type})" + ) + + +@dataclass +class QuoteV4: + """ + TDX Quote Version 4. + + The complete TDX attestation quote containing header, TD quote body, + and signed data with certification chain. + """ + header: TdxHeader + td_quote_body: TdQuoteBody + signed_data_size: int + signed_data: SignedData + extra_bytes: bytes = b"" # Any trailing bytes after signed data + + def __str__(self) -> str: + return ( + f"QuoteV4(\n" + f" header={self.header},\n" + f" td_quote_body={self.td_quote_body},\n" + f" signed_data_size={self.signed_data_size},\n" + f" signed_data={self.signed_data}\n" + f")" + ) + + def get_measurements(self) -> List[bytes]: + """Return the 5 TDX measurements: [MRTD, RTMR0, RTMR1, RTMR2, RTMR3].""" + return self.td_quote_body.get_measurements() + + def get_report_data(self) -> bytes: + """Return the 64-byte report data containing TLS key FP and HPKE key.""" + return self.td_quote_body.report_data + + +# ============================================================================= +# Parsing Functions +# ============================================================================= + + + +class TdxQuoteParseError(Exception): + """Raised when TDX quote parsing fails.""" + pass + + +def _parse_header(data: bytes) -> TdxHeader: + """ + Parse the 48-byte TDX quote header. + + Args: + data: 48 bytes of header data + + Returns: + Parsed TdxHeader + + Raises: + TdxQuoteParseError: If header is malformed + """ + if len(data) < HEADER_SIZE: + raise TdxQuoteParseError( + f"Header too short: {len(data)} bytes, expected {HEADER_SIZE}" + ) + + version = struct.unpack_from(" TdQuoteBody: + """ + Parse the 584-byte TD quote body. + + Args: + data: 584 bytes of TD quote body data + + Returns: + Parsed TdQuoteBody + + Raises: + TdxQuoteParseError: If body is malformed + """ + if len(data) < TD_QUOTE_BODY_SIZE: + raise TdxQuoteParseError( + f"TD quote body too short: {len(data)} bytes, expected {TD_QUOTE_BODY_SIZE}" + ) + + # Parse RTMRs (4 x 48 bytes) + rtmrs = [] + for i in range(RTMR_COUNT): + start = TD_RTMRS_START + (i * RTMR_SIZE) + end = start + RTMR_SIZE + rtmrs.append(data[start:end]) + + return TdQuoteBody( + tee_tcb_svn=data[TD_TEE_TCB_SVN_START:TD_TEE_TCB_SVN_END], + mr_seam=data[TD_MR_SEAM_START:TD_MR_SEAM_END], + mr_signer_seam=data[TD_MR_SIGNER_SEAM_START:TD_MR_SIGNER_SEAM_END], + seam_attributes=data[TD_SEAM_ATTRIBUTES_START:TD_SEAM_ATTRIBUTES_END], + td_attributes=data[TD_ATTRIBUTES_START:TD_ATTRIBUTES_END], + xfam=data[TD_XFAM_START:TD_XFAM_END], + mr_td=data[TD_MR_TD_START:TD_MR_TD_END], + mr_config_id=data[TD_MR_CONFIG_ID_START:TD_MR_CONFIG_ID_END], + mr_owner=data[TD_MR_OWNER_START:TD_MR_OWNER_END], + mr_owner_config=data[TD_MR_OWNER_CONFIG_START:TD_MR_OWNER_CONFIG_END], + rtmrs=rtmrs, + report_data=data[TD_REPORT_DATA_START:TD_REPORT_DATA_END], + ) + + +def _parse_qe_report(data: bytes) -> QeReport: + """ + Parse the 384-byte QE report. + + Args: + data: 384 bytes of QE report data + + Returns: + Parsed QeReport + + Raises: + TdxQuoteParseError: If report is malformed + """ + if len(data) < QE_REPORT_SIZE: + raise TdxQuoteParseError( + f"QE report too short: {len(data)} bytes, expected {QE_REPORT_SIZE}" + ) + + return QeReport( + cpu_svn=data[QE_CPU_SVN_START:QE_CPU_SVN_END], + misc_select=struct.unpack_from(" tuple[PckCertChainData, int]: + """ + Parse PCK certificate chain data. + + Args: + data: Raw bytes starting at PCK cert chain data + + Returns: + Tuple of (PckCertChainData, bytes_consumed) + + Raises: + TdxQuoteParseError: If data is malformed + """ + if len(data) < CERT_DATA_HEADER_SIZE: + raise TdxQuoteParseError("PCK cert chain data too short for header") + + cert_type = struct.unpack_from(" tuple[QeReportCertificationData, int]: + """ + Parse QE report certification data (type 6). + + Structure: + - QE Report: 384 bytes + - QE Report Signature: 64 bytes + - QE Auth Data Size: 2 bytes + - QE Auth Data: variable + - PCK Cert Chain Data: variable (nested type 5) + + Args: + data: Raw bytes starting at QE report certification data + + Returns: + Tuple of (QeReportCertificationData, bytes_consumed) + + Raises: + TdxQuoteParseError: If data is malformed + """ + offset = 0 + + # QE Report (384 bytes) + if len(data) < offset + QE_REPORT_SIZE: + raise TdxQuoteParseError("Data too short for QE report") + qe_report_raw = data[offset:offset + QE_REPORT_SIZE] + qe_report_parsed = _parse_qe_report(qe_report_raw) + offset += QE_REPORT_SIZE + + # QE Report Signature (64 bytes) + if len(data) < offset + SIGNATURE_SIZE: + raise TdxQuoteParseError("Data too short for QE report signature") + qe_report_signature = data[offset:offset + SIGNATURE_SIZE] + offset += SIGNATURE_SIZE + + # QE Auth Data Size (2 bytes) + Auth Data + if len(data) < offset + 2: + raise TdxQuoteParseError("Data too short for QE auth data size") + qe_auth_data_size = struct.unpack_from(" tuple[CertificationData, int]: + """ + Parse certification data from signed data section. + + Only type 6 (QE report certification data) is supported. Type 5 (direct + PCK cert chain) is rejected as it lacks the QE report needed for + attestation key binding verification. + + Args: + data: Raw bytes starting at certification data + + Returns: + Tuple of (CertificationData, bytes_consumed) + + Raises: + TdxQuoteParseError: If data is malformed or unsupported type + """ + if len(data) < CERT_DATA_HEADER_SIZE: + raise TdxQuoteParseError("Certification data too short for header") + + cert_type = struct.unpack_from(" SignedData: + """ + Parse the signed data section of the quote. + + Structure: + - Signature: 64 bytes (ECDSA R || S) + - Attestation Key: 64 bytes (raw P-256 public key) + - Certification Data: variable + + Args: + data: Raw bytes of signed data section + + Returns: + Parsed SignedData + + Raises: + TdxQuoteParseError: If data is malformed + """ + min_size = SIGNATURE_SIZE + ATTESTATION_KEY_SIZE + CERT_DATA_HEADER_SIZE + if len(data) < min_size: + raise TdxQuoteParseError( + f"Signed data too short: {len(data)} bytes, minimum {min_size}" + ) + + signature = data[SIGNED_DATA_SIGNATURE_START:SIGNED_DATA_SIGNATURE_END] + attestation_key = data[SIGNED_DATA_AK_START:SIGNED_DATA_AK_END] + + certification_data, _ = _parse_certification_data(data[SIGNED_DATA_CERT_DATA_START:]) + + return SignedData( + signature=signature, + attestation_key=attestation_key, + certification_data=certification_data, + ) + + +def _validate_header(header: TdxHeader) -> None: + """ + Validate TDX quote header fields. + + Args: + header: Parsed header to validate + + Raises: + TdxQuoteParseError: If validation fails + """ + if header.version == QUOTE_VERSION_V5: + raise TdxQuoteParseError( + "TDX QuoteV5 is not supported. Only QuoteV4 is implemented." + ) + + if header.version != QUOTE_VERSION_V4: + raise TdxQuoteParseError( + f"Unsupported quote version: {header.version}. Expected {QUOTE_VERSION_V4}." + ) + + if header.attestation_key_type != ATTESTATION_KEY_TYPE_ECDSA_P256: + raise TdxQuoteParseError( + f"Unsupported attestation key type: {header.attestation_key_type}. " + f"Expected {ATTESTATION_KEY_TYPE_ECDSA_P256} (ECDSA-P256)." + ) + + if header.tee_type != TEE_TDX: + raise TdxQuoteParseError( + f"Invalid TEE type: 0x{header.tee_type:x}. Expected 0x{TEE_TDX:x} (TDX)." + ) + + if header.qe_vendor_id != INTEL_QE_VENDOR_ID: + raise TdxQuoteParseError( + f"Unknown QE vendor ID: {header.qe_vendor_id.hex()}. " + f"Expected Intel QE: {INTEL_QE_VENDOR_ID.hex()}" + ) + + +def parse_quote(data: bytes) -> QuoteV4: + """ + Parse a TDX attestation quote from raw bytes. + + This is the main entry point for TDX quote parsing. It handles QuoteV4 + format and explicitly rejects QuoteV5. + + Args: + data: Raw quote bytes (typically from base64-decoded, gzip-decompressed + attestation document) + + Returns: + Parsed QuoteV4 structure + + Raises: + TdxQuoteParseError: If parsing fails or quote format is unsupported + + Example: + >>> raw_quote = base64.b64decode(attestation_doc) + >>> decompressed = gzip.decompress(raw_quote) + >>> quote = parse_quote(decompressed) + >>> measurements = quote.get_measurements() + """ + if len(data) < QUOTE_MIN_SIZE: + raise TdxQuoteParseError( + f"Quote too short: {len(data)} bytes, minimum {QUOTE_MIN_SIZE}" + ) + + # Parse header first to check version + header = _parse_header(data[QUOTE_HEADER_START:QUOTE_HEADER_END]) + _validate_header(header) + + # Parse TD quote body + td_quote_body = _parse_td_quote_body(data[QUOTE_BODY_START:QUOTE_BODY_END]) + + # Get signed data size and parse signed data + signed_data_size = struct.unpack_from( + " signed_data_end else b"" + + return QuoteV4( + header=header, + td_quote_body=td_quote_body, + signed_data_size=signed_data_size, + signed_data=signed_data, + extra_bytes=extra_bytes, + ) diff --git a/src/tinfoil/attestation/attestation.py b/src/tinfoil/attestation/attestation.py index 2de3c12..7e2f3b5 100644 --- a/src/tinfoil/attestation/attestation.py +++ b/src/tinfoil/attestation/attestation.py @@ -1,76 +1,24 @@ -from dataclasses import dataclass -from enum import Enum -import json - -import base64 -import gzip import hashlib +import json import ssl -from typing import List, Optional +from dataclasses import dataclass + import requests from cryptography import x509 from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ec -from .validate import validate_report, ValidationOptions -from .verify import Report, verify_attestation, CertificateChain -from .abi_sevsnp import TCBParts, SnpPolicy, SnpPlatformInfo - - -class PredicateType(str, Enum): - """Predicate types for attestation""" - SEV_GUEST_V1 = "https://tinfoil.sh/predicate/sev-snp-guest/v1" # Deprecated - SEV_GUEST_V2 = "https://tinfoil.sh/predicate/sev-snp-guest/v2" - TDX_GUEST_V1 = "https://tinfoil.sh/predicate/tdx-guest/v1" # Deprecated - SNP_TDX_MULTIPLATFORM_v1 = "https://tinfoil.sh/predicate/snp-tdx-multiplatform/v1" +from .types import ( + PredicateType, + Measurement, + Verification, +) +from .attestation_tdx import verify_tdx_attestation_v2 +from .attestation_sev import verify_sev_attestation_v2 ATTESTATION_ENDPOINT = "/.well-known/tinfoil-attestation" +REQUEST_TIMEOUT_SECONDS = 15 -class AttestationError(Exception): - """Base class for attestation errors""" - pass - -class FormatMismatchError(AttestationError): - """Raised when attestation formats don't match""" - pass - -class MeasurementMismatchError(AttestationError): - """Raised when measurements don't match""" - pass - -@dataclass -class Measurement: - """Represents measurement data""" - type: PredicateType - registers: List[str] - - def fingerprint(self) -> str: - """ - Computes the SHA-256 hash of all measurements, - or returns the single measurement if there is only one - """ - if len(self.registers) == 1: - return self.registers[0] - - all_data = str(self.type) + "".join(self.registers) - return hashlib.sha256(all_data.encode()).hexdigest() - - def equals(self, other: 'Measurement') -> None: - """ - Checks if this measurement equals another measurement - Raises appropriate error if they don't match - """ - if self.type != other.type: - raise FormatMismatchError() - if len(self.registers) != len(other.registers) or self.registers != other.registers: - raise MeasurementMismatchError() - -@dataclass -class Verification: - """Represents verification results""" - measurement: Measurement - public_key_fp: str - hpke_public_key: Optional[str] = None @dataclass class Document: @@ -90,12 +38,21 @@ def verify(self) -> Verification: """ if self.format == PredicateType.SEV_GUEST_V2: return verify_sev_attestation_v2(self.body) + elif self.format == PredicateType.TDX_GUEST_V2: + return verify_tdx_attestation_v2(self.body) else: raise ValueError(f"Unsupported attestation format: {self.format}") def verify_attestation_json(json_data: bytes) -> Verification: """Verifies an attestation document in JSON format and returns the inner measurements""" - doc_dict = json.loads(json_data) + try: + doc_dict = json.loads(json_data) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Invalid attestation JSON: {e}") from e + + if not isinstance(doc_dict, dict) or "format" not in doc_dict or "body" not in doc_dict: + raise ValueError("Attestation JSON must contain 'format' and 'body' fields") + doc = Document( format=PredicateType(doc_dict["format"]), body=doc_dict["body"] @@ -130,129 +87,14 @@ def connection_cert_fp(ssl_socket: ssl.SSLSocket) -> str: def fetch_attestation(host: str) -> Document: """Retrieves the attestation document from a given enclave hostname""" url = f"https://{host}{ATTESTATION_ENDPOINT}" - response = requests.get(url) + response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS) response.raise_for_status() doc_dict = response.json() + if not isinstance(doc_dict, dict) or "format" not in doc_dict or "body" not in doc_dict: + raise ValueError(f"Invalid attestation response from {host}: missing 'format' or 'body'") + return Document( format=PredicateType(doc_dict["format"]), body=doc_dict["body"] ) - -min_tcb = TCBParts( - bl_spl=0x7, - tee_spl=0, - snp_spl=0xe, - ucode_spl=0x48, -) - -default_validation_options = ValidationOptions( - guest_policy=SnpPolicy( - abi_minor=0, - abi_major=0, - smt=True, - migrate_ma=False, - debug=False, - single_socket=False, - cxl_allowed=False, - mem_aes256_xts=False, - rapl_dis=False, - ciphertext_hiding_dram=False, - page_swap_disabled=False, - ), - minimum_guest_svn=0, - minimum_build=21, - minimum_version=(1 << 8) | 55, # 1.55 - minimum_tcb=min_tcb, - minimum_launch_tcb=min_tcb, - permit_provisional_firmware=False, # We only support False per your requirement - platform_info=SnpPlatformInfo( - smt_enabled=True, - tsme_enabled=True, - ecc_enabled=False, - rapl_disabled=False, - ciphertext_hiding_dram_enabled=False, - alias_check_complete=False, - tio_enabled=False, - ), - require_author_key=False, - require_id_block=False, -) - -def verify_sev_attestation_v2(attestation_doc: str) -> Verification: - """Verify SEV attestation document and return verification result.""" - report = verify_sev_report(attestation_doc, True) - - # Create measurement object - measurement = Measurement( - type=PredicateType.SEV_GUEST_V2, - registers=[ - report.measurement.hex() - ] - ) - - keys = report.report_data - tls_key_fp = keys[0:32] - hpke_public_key = keys[32:64] - - return Verification( - measurement=measurement, - public_key_fp=tls_key_fp.hex(), - hpke_public_key=hpke_public_key.hex() - ) - - -def verify_sev_report(attestation_doc: str, is_compressed: bool) -> Report: - """Verify SEV attestation document and return verification result.""" - try: - att_doc_bytes = base64.b64decode(attestation_doc) - except Exception as e: - raise ValueError(f"Failed to decode base64: {e}") - - if is_compressed: - att_doc_bytes = gzip.decompress(att_doc_bytes) - - # Parse the report - try: - report = Report(att_doc_bytes) - except Exception as e: - raise ValueError(f"Failed to parse report: {e}") - - # Get attestation chain - chain: CertificateChain = CertificateChain.from_report(report) - - # Verify attestation - try: - res = verify_attestation(chain, report) - except Exception as e: - raise ValueError(f"Failed to verify attestation: {e}") - - if not res: - raise ValueError("Attestation verification failed!") - - # Validate report - try: - validate_report(report, chain, default_validation_options) - except Exception as e: - raise ValueError(f"Failed to validate report: {e}") - - return report - -def from_snp_digest(snp_digest: str) -> dict: - """ - Convert an SNP launch digest string to measurement format. - - Args: - snp_digest: The SNP launch digest as a hex string - - Returns: - Dictionary in the format expected by SecureClient measurement parameter - - Example: - from tinfoil.attestation import from_snp_digest - measurement = from_snp_digest("abcdef") - client = TinfoilAI(measurement=measurement) - """ - return { - "snp_measurement": snp_digest - } diff --git a/src/tinfoil/attestation/attestation_sev.py b/src/tinfoil/attestation/attestation_sev.py new file mode 100644 index 0000000..30bb6a5 --- /dev/null +++ b/src/tinfoil/attestation/attestation_sev.py @@ -0,0 +1,160 @@ +""" +AMD SEV-SNP Attestation Orchestration Module. + +This module provides the high-level entry point for AMD SEV-SNP attestation verification. + +""" + +import base64 +from typing import Optional + +from .abi_sev import TCBParts, SnpPolicy, SnpPlatformInfo, Report +from .types import Measurement, Verification, PredicateType, TLS_KEY_FP_SIZE, HPKE_KEY_SIZE +from .utils import safe_gzip_decompress +from .verify_sev import verify_attestation, CertificateChain +from .validate_sev import validate_report, ValidationOptions + + +class SevAttestationError(Exception): + """Raised when SEV-SNP attestation verification fails.""" + pass + + +# ============================================================================= +# Orchestration Constants +# ============================================================================= + +# Minimum TCB requirements for AMD SEV-SNP +min_tcb = TCBParts( + bl_spl=0x7, + tee_spl=0, + snp_spl=0xe, + ucode_spl=0x48, +) + +# Default validation options for AMD SEV-SNP attestation +default_validation_options = ValidationOptions( + guest_policy=SnpPolicy( + abi_minor=0, + abi_major=0, + smt=True, + migrate_ma=False, + debug=False, + single_socket=False, + cxl_allowed=False, + mem_aes256_xts=False, + rapl_dis=False, + ciphertext_hiding_dram=False, + page_swap_disabled=False, + ), + minimum_guest_svn=0, + minimum_build=21, + minimum_version=(1 << 8) | 55, # 1.55 + minimum_tcb=min_tcb, + minimum_launch_tcb=min_tcb, + permit_provisional_firmware=False, + platform_info=SnpPlatformInfo( + smt_enabled=True, + tsme_enabled=False, + ecc_enabled=False, + rapl_disabled=False, + ciphertext_hiding_dram_enabled=False, + alias_check_complete=False, + tio_enabled=False, + ), + require_author_key=False, + require_id_block=False, +) + + +# ============================================================================= +# Main Entry Points +# ============================================================================= + +def verify_sev_attestation_v2(attestation_doc: str) -> Verification: + """Verify SEV attestation document (v2 format) and return verification result. + + Raises: + ValueError: If verification fails + """ + try: + report = verify_sev_report(attestation_doc) + except SevAttestationError as e: + raise ValueError(f"SEV attestation verification failed: {e}") from e + + # Create measurement object + measurement = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=[ + report.measurement.hex() + ] + ) + + keys = report.report_data + required_len = TLS_KEY_FP_SIZE + HPKE_KEY_SIZE + if len(keys) < required_len: + raise ValueError( + f"report_data too short: {len(keys)} bytes, need at least {required_len}" + ) + tls_key_fp = keys[0:TLS_KEY_FP_SIZE] + hpke_public_key = keys[TLS_KEY_FP_SIZE:TLS_KEY_FP_SIZE + HPKE_KEY_SIZE] + + return Verification( + measurement=measurement, + public_key_fp=tls_key_fp.hex(), + hpke_public_key=hpke_public_key.hex() + ) + + +def verify_sev_report( + attestation_doc: str, + is_compressed: bool = True, + validation_options: Optional[ValidationOptions] = None, +) -> Report: + """Verify SEV attestation document and return the parsed report. + + Args: + attestation_doc: Base64-encoded attestation document + is_compressed: Whether the document is gzip-compressed (default True) + validation_options: Custom validation options; uses module defaults if None + """ + options = validation_options if validation_options is not None else default_validation_options + + try: + att_doc_bytes = base64.b64decode(attestation_doc) + except Exception as e: + raise SevAttestationError(f"Failed to decode base64: {e}") from e + + if is_compressed: + try: + att_doc_bytes = safe_gzip_decompress(att_doc_bytes) + except SevAttestationError: + raise + except Exception as e: + raise SevAttestationError(f"Failed to decompress attestation document: {e}") from e + + try: + report = Report(att_doc_bytes) + except Exception as e: + raise SevAttestationError(f"Failed to parse report: {e}") from e + + try: + chain = CertificateChain.from_report(report) + except Exception as e: + raise SevAttestationError(f"Failed to build certificate chain: {e}") from e + + try: + res = verify_attestation(chain, report) + except Exception as e: + raise SevAttestationError(f"Failed to verify attestation: {e}") from e + + if not res: + raise SevAttestationError("Attestation verification failed!") + + try: + validate_report(report, chain, options) + except Exception as e: + raise SevAttestationError(f"Failed to validate report: {e}") from e + + return report + diff --git a/src/tinfoil/attestation/attestation_tdx.py b/src/tinfoil/attestation/attestation_tdx.py new file mode 100644 index 0000000..e99f75a --- /dev/null +++ b/src/tinfoil/attestation/attestation_tdx.py @@ -0,0 +1,317 @@ +""" +TDX Attestation Orchestration Module. + +This module provides the high-level entry point for TDX attestation verification. +It coordinates between: +- Quote parsing (abi_tdx) +- Cryptographic verification (verify_tdx) +- Policy validation (validate_tdx) +- Collateral validation (collateral_tdx) + +Usage: + from tinfoil.attestation.attestation_tdx import verify_tdx_attestation + + result = verify_tdx_attestation(attestation_doc) + # result.measurements contains the 5 TDX measurements + # result.tls_key_fp contains the TLS key fingerprint +""" + +import base64 +from dataclasses import dataclass, field +from typing import Optional + +from .abi_tdx import ( + parse_quote, + QuoteV4, + TdxQuoteParseError, + INTEL_QE_VENDOR_ID, + MR_CONFIG_ID_SIZE, + MR_OWNER_SIZE, + MR_OWNER_CONFIG_SIZE, +) +from .types import TLS_KEY_FP_SIZE, HPKE_KEY_SIZE, Measurement, Verification, PredicateType, HardwareMeasurement, HardwareMeasurementError, TDX_MRTD_IDX, TDX_RTMR0_IDX, TDX_REGISTER_COUNT +from .utils import safe_gzip_decompress +from .verify_tdx import ( + verify_tdx_quote, + TdxVerificationError, + PCKCertificateChain, +) +from .validate_tdx import ( + validate_tdx_policy, + PolicyOptions, + TdQuoteBodyOptions, + HeaderOptions, + TdxValidationError, +) +from .pck_extensions import PckExtensions +from .collateral_tdx import ( + validate_collateral, + CollateralError, + TdxCollateral, + TcbLevel, + DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER, +) + + +# ============================================================================= +# Orchestration Constants +# ============================================================================= + +# Expected values for policy validation +# TdAttributes: All zeros except SEPT_VE_DISABLE=1 +EXPECTED_TD_ATTRIBUTES = bytes.fromhex("0000001000000000") +# XFam: Enable FP, SSE, AVX, AVX512, PK, AMX +EXPECTED_XFAM = bytes.fromhex("e702060000000000") +# MinimumTeeTcbSvn: 3.1.2 +EXPECTED_MINIMUM_TEE_TCB_SVN = bytes.fromhex("03010200000000000000000000000000") + +# Accepted MR_SEAM values from Intel TDX module releases +# https://github.com/intel/confidential-computing.tdx.tdx-module/releases +ACCEPTED_MR_SEAMS: tuple[bytes, ...] = ( + bytes.fromhex("476a2997c62bccc78370913d0a80b956e3721b24272bc66c4d6307ced4be2865c40e26afac75f12df3425b03eb59ea7c"), # TDX Module 2.0.08 + bytes.fromhex("7bf063280e94fb051f5dd7b1fc59ce9aac42bb961df8d44b709c9b0ff87a7b4df648657ba6d1189589feab1d5a3c9a9d"), # TDX Module 1.5.16 + bytes.fromhex("685f891ea5c20e8fa27b151bf34bf3b50fbaf7143cc53662727cbdb167c0ad8385f1f6f3571539a91e104a1c96d75e04"), # TDX Module 2.0.02 + bytes.fromhex("49b66faa451d19ebbdbe89371b8daf2b65aa3984ec90110343e9e2eec116af08850fa20e3b1aa9a874d77a65380ee7e6"), # TDX Module 1.5.08 +) + + +# ============================================================================= +# Orchestration Config +# ============================================================================= + +@dataclass(frozen=True) +class TdxVerificationConfig: + """ + Configuration for TDX attestation verification. + + All fields have sensible defaults matching the current hardcoded values. + Override individual fields to customise verification policy. + """ + min_tcb_evaluation_data_number: int = DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER + accepted_mr_seams: tuple[bytes, ...] = ACCEPTED_MR_SEAMS + policy_options: Optional[PolicyOptions] = None + expected_td_attributes: bytes = EXPECTED_TD_ATTRIBUTES + expected_xfam: bytes = EXPECTED_XFAM + expected_minimum_tee_tcb_svn: bytes = EXPECTED_MINIMUM_TEE_TCB_SVN + + +_DEFAULT_CONFIG = TdxVerificationConfig() + + +# ============================================================================= +# Orchestration Error Type +# ============================================================================= + +class TdxAttestationError(Exception): + """ + Raised when TDX attestation verification fails. + """ + pass + + +# ============================================================================= +# Orchestration Result Type +# ============================================================================= + +@dataclass +class TdxAttestationResult: + """ + Result of TDX attestation verification. + + Contains the verified quote, measurements, and TCB status. + """ + quote: QuoteV4 + pck_chain: PCKCertificateChain + pck_extensions: PckExtensions + collateral: TdxCollateral + tcb_level: TcbLevel + measurements: list[str] # [MRTD, RTMR0, RTMR1, RTMR2, RTMR3] + tls_key_fp: str + hpke_public_key: str + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + +def verify_tdx_attestation( + attestation_doc: str, + is_compressed: bool = True, + config: Optional[TdxVerificationConfig] = None, +) -> TdxAttestationResult: + """ + Verify a TDX attestation document. + + This is the main entry point for TDX attestation verification. + It performs the complete verification flow: + + 1. Decode and decompress the attestation document + 2. Parse the TDX quote + 3. Verify cryptographic signatures (PCK chain, quote, QE report) + 4. Policy validation (XFAM, TD_ATTRIBUTES, SEAM, MR_SEAM whitelist) + 5. Fetch and validate collateral (PCK extensions, TCB status, QE identity, revocation) + 6. Extract measurements and report data + + Args: + attestation_doc: Base64-encoded attestation document + is_compressed: Whether the document is gzip compressed + config: Optional verification config. Uses sensible defaults if None. + + Returns: + TdxAttestationResult containing verified data + + Raises: + TdxAttestationError: If any verification step fails + """ + if config is None: + config = _DEFAULT_CONFIG + # Step 1: Decode the attestation document + try: + raw_bytes = base64.b64decode(attestation_doc) + except Exception as e: + raise TdxAttestationError(f"Failed to decode base64: {e}") from e + + if is_compressed: + try: + raw_bytes = safe_gzip_decompress(raw_bytes) + except Exception as e: + raise TdxAttestationError(f"Failed to decompress: {e}") from e + + # Step 2: Parse the TDX quote + try: + quote = parse_quote(raw_bytes) + except TdxQuoteParseError as e: + raise TdxAttestationError(f"Failed to parse TDX quote: {e}") from e + + # Step 3: Verify cryptographic signatures + try: + pck_chain = verify_tdx_quote(quote, raw_bytes) + except TdxVerificationError as e: + raise TdxAttestationError(f"TDX quote verification failed: {e}") from e + + # Step 4: Policy validation (mirrors Go's validate.TdxQuote) + if config.policy_options is not None: + policy_options = config.policy_options + else: + policy_options = PolicyOptions( + header=HeaderOptions( + qe_vendor_id=INTEL_QE_VENDOR_ID, + ), + td_quote_body=TdQuoteBodyOptions( + minimum_tee_tcb_svn=config.expected_minimum_tee_tcb_svn, + td_attributes=config.expected_td_attributes, + xfam=config.expected_xfam, + mr_config_id=b'\x00' * MR_CONFIG_ID_SIZE, + mr_owner=b'\x00' * MR_OWNER_SIZE, + mr_owner_config=b'\x00' * MR_OWNER_CONFIG_SIZE, + any_mr_seam=config.accepted_mr_seams, + ), + ) + + try: + validate_tdx_policy(quote, policy_options) + except TdxValidationError as e: + raise TdxAttestationError(f"Policy validation failed: {e}") from e + + # Step 5: Validate collateral (PCK extensions, TCB status, QE identity, revocation) + try: + collateral_result = validate_collateral( + quote=quote, + pck_chain=pck_chain, + min_tcb_evaluation_data_number=config.min_tcb_evaluation_data_number, + ) + except CollateralError as e: + raise TdxAttestationError(f"Collateral validation failed: {e}") from e + + # Step 6: Extract measurements and report data + measurements = quote.td_quote_body.get_measurements() + measurements_hex = [m.hex() for m in measurements] + + report_data = quote.td_quote_body.report_data + required_len = TLS_KEY_FP_SIZE + HPKE_KEY_SIZE + if len(report_data) < required_len: + raise TdxAttestationError( + f"report_data too short: {len(report_data)} bytes, need at least {required_len}" + ) + tls_key_fp = report_data[0:TLS_KEY_FP_SIZE].hex() + hpke_public_key = report_data[TLS_KEY_FP_SIZE:TLS_KEY_FP_SIZE + HPKE_KEY_SIZE].hex() + + return TdxAttestationResult( + quote=quote, + pck_chain=pck_chain, + pck_extensions=collateral_result.pck_extensions, + collateral=collateral_result.collateral, + tcb_level=collateral_result.tcb_level, + measurements=measurements_hex, + tls_key_fp=tls_key_fp, + hpke_public_key=hpke_public_key, + ) + + +def verify_tdx_attestation_v2(attestation_doc: str) -> Verification: + """ + Verify TDX attestation document (v2 format) and return verification result. + + v2 format: report_data contains TLS key fingerprint (32 bytes) + HPKE public key (32 bytes). + + Args: + attestation_doc: Base64-encoded, gzip-compressed TDX quote + + Returns: + Verification containing measurements, public key fingerprint, and HPKE public key + + Raises: + ValueError: If verification fails + """ + try: + result = verify_tdx_attestation(attestation_doc, is_compressed=True) + except TdxAttestationError as e: + raise ValueError(f"TDX attestation verification failed: {e}") from e + + measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=result.measurements, + ) + + return Verification( + measurement=measurement, + public_key_fp=result.tls_key_fp, + hpke_public_key=result.hpke_public_key, + ) + + +def verify_tdx_hardware( + hardware_measurements: list[HardwareMeasurement], + enclave_measurement: Measurement, +) -> HardwareMeasurement: + """ + Verify that the enclave's MRTD and RTMR0 match a known hardware platform. + + Args: + hardware_measurements: List of known-good hardware measurements from Sigstore + enclave_measurement: The measurement from the TDX enclave attestation + + Returns: + The matching HardwareMeasurement + + Raises: + HardwareMeasurementError: If no matching hardware platform is found + ValueError: If enclave measurement is invalid + """ + if enclave_measurement is None: + raise ValueError("enclave measurement is None") + + if enclave_measurement.type != PredicateType.TDX_GUEST_V2: + raise ValueError(f"unsupported enclave platform: {enclave_measurement.type}") + + if len(enclave_measurement.registers) != TDX_REGISTER_COUNT: + raise ValueError(f"expected {TDX_REGISTER_COUNT} TDX registers, got {len(enclave_measurement.registers)}") + + enclave_mrtd = enclave_measurement.registers[TDX_MRTD_IDX] + enclave_rtmr0 = enclave_measurement.registers[TDX_RTMR0_IDX] + + for hw in hardware_measurements: + if hw.mrtd == enclave_mrtd and hw.rtmr0 == enclave_rtmr0: + return hw + + raise HardwareMeasurementError("no matching hardware platform found") diff --git a/src/tinfoil/attestation/cert_utils.py b/src/tinfoil/attestation/cert_utils.py new file mode 100644 index 0000000..911e673 --- /dev/null +++ b/src/tinfoil/attestation/cert_utils.py @@ -0,0 +1,152 @@ +""" +Shared certificate utilities for TDX attestation verification. + +This module provides common certificate parsing and chain verification +functions used by both verify_tdx.py and collateral_tdx.py. +""" + +from datetime import datetime, timezone +from typing import List + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization + +from .intel_root_ca import get_intel_root_ca + +class CertificateChainError(Exception): + """Raised when certificate chain verification fails.""" + pass + + +def parse_pem_chain(pem_data: bytes) -> List[x509.Certificate]: + """ + Parse concatenated PEM certificates. + + Handles: + - Concatenated PEM certificates + - Leading/trailing whitespace and null bytes + - Trailing null bytes between certificates (common in TDX quotes) + + Args: + pem_data: PEM-encoded certificate chain (bytes) + + Returns: + List of parsed certificates in order + + Raises: + CertificateChainError: If parsing fails + """ + try: + return x509.load_pem_x509_certificates(pem_data) + except Exception as e: + raise CertificateChainError(f"Failed to parse PEM certificate chain: {e}") from e + + +def certs_to_pem(certs: List[x509.Certificate]) -> str: + """ + Convert list of certificates to concatenated PEM string. + + Args: + certs: List of certificates + + Returns: + Concatenated PEM string + """ + pem_parts = [] + for cert in certs: + pem_parts.append(cert.public_bytes(serialization.Encoding.PEM).decode("ascii")) + return "".join(pem_parts) + + +def verify_intel_chain( + certs: List[x509.Certificate], + chain_name: str = "Certificate chain", +) -> None: + """ + Verify a certificate chain against Intel SGX Root CA. + + Performs manual chain verification without requiring TLS extensions (SAN). + Intel PCK certificates don't have SAN extension, so we can't use + PolicyBuilder.build_client_verifier(). + + Verification steps: + 1. Verify root cert matches embedded Intel SGX Root CA (by public key) + 2. Verify each certificate's validity period + 3. Verify issuer certificates have BasicConstraints CA=True + 4. Verify each certificate was issued by the next cert in chain + + Args: + certs: Certificate chain [leaf, intermediate(s)..., root] + chain_name: Human-readable name for error messages + + Raises: + CertificateChainError: If chain verification fails + """ + if len(certs) < 2: + raise CertificateChainError( + f"{chain_name} must contain at least 2 certificates (leaf and root)" + ) + + intel_root = get_intel_root_ca() + + # Step 1: Verify root certificate matches Intel SGX Root CA by public key + chain_root = certs[-1] + chain_root_pubkey = chain_root.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + intel_root_pubkey = intel_root.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + if chain_root_pubkey != intel_root_pubkey: + raise CertificateChainError( + f"{chain_name} root certificate does not match Intel SGX Root CA" + ) + + # Step 2: Verify each certificate's validity period + now = datetime.now(timezone.utc) + for cert in certs: + if now < cert.not_valid_before_utc: + raise CertificateChainError( + f"{chain_name}: certificate not yet valid (not before {cert.not_valid_before_utc})" + ) + if now > cert.not_valid_after_utc: + raise CertificateChainError( + f"{chain_name}: certificate expired (not after {cert.not_valid_after_utc})" + ) + + # Step 3: Verify issuer certificates are CAs (BasicConstraints) + # Go's x509.Verify checks this implicitly; we must do it explicitly. + for cert in certs[1:]: + try: + bc = cert.extensions.get_extension_for_class(x509.BasicConstraints) + if not bc.value.ca: + raise CertificateChainError( + f"{chain_name}: intermediate/root certificate is not a CA" + ) + except x509.ExtensionNotFound: + raise CertificateChainError( + f"{chain_name}: intermediate/root certificate missing BasicConstraints extension" + ) + + # Step 4: Verify certificate chain signatures using verify_directly_issued_by + # Each cert[i] should be signed by cert[i+1] + for i in range(len(certs) - 1): + cert = certs[i] + issuer = certs[i + 1] + try: + cert.verify_directly_issued_by(issuer) + except Exception as e: + raise CertificateChainError( + f"{chain_name}: certificate chain signature verification failed: {e}" + ) + + # Step 5: Verify the chain's root is signed by the trusted Intel root + try: + chain_root.verify_directly_issued_by(intel_root) + except Exception as e: + raise CertificateChainError( + f"{chain_name}: root certificate verification against Intel SGX Root CA failed: {e}" + ) diff --git a/src/tinfoil/attestation/collateral_tdx.py b/src/tinfoil/attestation/collateral_tdx.py new file mode 100644 index 0000000..21a9a6b --- /dev/null +++ b/src/tinfoil/attestation/collateral_tdx.py @@ -0,0 +1,2105 @@ +""" +Intel TDX Collateral Fetching and TCB Validation. + +This module handles fetching and validating collateral from Intel's +Provisioning Certification Service (PCS) for TDX attestation: + +- TCB Info: Contains TCB levels and status for the platform +- QE Identity: Contains Quoting Enclave identity information + +Intel PCS API: + Base URL: https://api.trustedservices.intel.com/tdx/certification/v4 + TCB Info: /tcb?fmspc={fmspc} + QE Identity: /qe/identity +""" + +import base64 +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +import json +import os +import stat +from typing import List, Optional, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from .abi_tdx import QuoteV4 + from .verify_tdx import PCKCertificateChain +from urllib.parse import unquote + +from cryptography import x509 +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature +import platformdirs +import requests + +from .intel_root_ca import get_intel_root_ca +from .pck_extensions import PckExtensions, extract_pck_extensions, PckExtensionError +from .cert_utils import ( + parse_pem_chain, + certs_to_pem, + verify_intel_chain, + CertificateChainError, +) + + +# Intel PCS API base URLs +INTEL_PCS_TDX_BASE_URL = "https://api.trustedservices.intel.com/tdx/certification/v4" +INTEL_PCS_SGX_BASE_URL = "https://api.trustedservices.intel.com/sgx/certification/v4" + +# Type alias for the HTTP session used by fetch functions. +# Pass a requests.Session to reuse connections, inject retries, or mock in tests. +HttpSession = Optional[requests.Session] + + +def _http_get(url: str, timeout: float, session: HttpSession = None) -> requests.Response: + """Issue a GET request via *session* (or a one-shot ``requests.get``).""" + try: + if session is not None: + resp = session.get(url, timeout=timeout) + else: + resp = requests.get(url, timeout=timeout) + resp.raise_for_status() + return resp + except requests.RequestException as e: + raise CollateralError(f"HTTP GET {url} failed: {e}") from e + +# Cache directory for TDX collateral +_TDX_CACHE_DIR = platformdirs.user_cache_dir("tinfoil", "tinfoil") + + +class CollateralError(Exception): + """Raised when collateral fetching or validation fails.""" + pass + + +class TcbStatus(str, Enum): + """TCB status values from Intel PCS.""" + UP_TO_DATE = "UpToDate" + SW_HARDENING_NEEDED = "SWHardeningNeeded" + CONFIGURATION_NEEDED = "ConfigurationNeeded" + CONFIGURATION_AND_SW_HARDENING_NEEDED = "ConfigurationAndSWHardeningNeeded" + OUT_OF_DATE = "OutOfDate" + OUT_OF_DATE_CONFIGURATION_NEEDED = "OutOfDateConfigurationNeeded" + REVOKED = "Revoked" + + +# ============================================================================= +# CRL Data Structures +# ============================================================================= + +@dataclass +class PckCrl: + """ + PCK Certificate Revocation List from Intel PCS. + + Contains the parsed CRL and metadata for caching. + """ + crl: x509.CertificateRevocationList + ca_type: str # "platform" or "processor" + next_update: datetime + + +@dataclass +class RootCrl: + """ + Intel SGX Root CA Certificate Revocation List. + + Used to check if intermediate CA certificates have been revoked. + """ + crl: x509.CertificateRevocationList + next_update: datetime + + +# ============================================================================= +# TCB Info Data Structures +# ============================================================================= + +@dataclass +class TcbComponent: + """A single TCB component with SVN and metadata.""" + svn: int + category: str = "" + type: str = "" + + +@dataclass +class Tcb: + """TCB level data containing SVN components.""" + sgx_tcb_components: List[TcbComponent] # 16 SGX components + pce_svn: int + tdx_tcb_components: List[TcbComponent] # 16 TDX components + isv_svn: Optional[int] = None # ISV SVN for QE Identity TCB levels + + +@dataclass +class TcbLevel: + """A TCB level with status and advisory IDs.""" + tcb: Tcb + tcb_date: str + tcb_status: TcbStatus + advisory_ids: List[str] + + +@dataclass +class TdxModule: + """TDX module identity information.""" + mrsigner: bytes # 48 bytes + attributes: bytes + attributes_mask: bytes + + +@dataclass +class TdxModuleIdentity: + """TDX module identity with associated TCB levels.""" + id: str # e.g., "TDX_01", "TDX_03" + mrsigner: bytes + attributes: bytes + attributes_mask: bytes + tcb_levels: List[TcbLevel] + + +@dataclass +class TcbInfo: + """TCB Info structure from Intel PCS.""" + id: str # Must be "TDX" + version: int # Must be 3 + issue_date: datetime + next_update: datetime + fmspc: str + pce_id: str + tcb_type: int + tcb_evaluation_data_number: int + tdx_module: Optional[TdxModule] + tdx_module_identities: List[TdxModuleIdentity] + tcb_levels: List[TcbLevel] + + +@dataclass +class TdxTcbInfo: + """Top-level TCB Info response with signature.""" + tcb_info: TcbInfo + signature: str + + +# ============================================================================= +# QE Identity Data Structures +# ============================================================================= + +@dataclass +class EnclaveIdentity: + """Quoting Enclave identity information.""" + id: str # Must be "TD_QE" + version: int # Must be 2 + issue_date: datetime + next_update: datetime + tcb_evaluation_data_number: int + miscselect: bytes # 4 bytes + miscselect_mask: bytes # 4 bytes + attributes: bytes # 16 bytes + attributes_mask: bytes # 16 bytes + mrsigner: bytes # 32 bytes + isv_prod_id: int + tcb_levels: List[TcbLevel] + + +@dataclass +class QeIdentity: + """Top-level QE Identity response with signature.""" + enclave_identity: EnclaveIdentity + signature: str + + +# ============================================================================= +# Collateral Container +# ============================================================================= + +@dataclass +class TdxCollateral: + """Container for all TDX collateral data.""" + tcb_info: TdxTcbInfo + qe_identity: QeIdentity + tcb_info_raw: bytes # Raw JSON for signature verification + qe_identity_raw: bytes # Raw JSON for signature verification + pck_crl: Optional[PckCrl] = None # PCK CRL for revocation checking + root_crl: Optional[RootCrl] = None # Root CA CRL for intermediate revocation checking + tcb_info_issuer_chain: Optional[List[x509.Certificate]] = None # TCB Info signing chain for CRL checking + qe_identity_issuer_chain: Optional[List[x509.Certificate]] = None # QE Identity signing chain for CRL checking + + +# ============================================================================= +# Collateral Cache Helpers +# ============================================================================= + +# Cache directory permissions (owner-only) +_CACHE_DIR_MODE = stat.S_IRWXU # 0700 +_CACHE_FILE_MODE = stat.S_IRUSR | stat.S_IWUSR # 0600 + + +@dataclass +class CacheEntry: + """ + Cache entry containing body and issuer chain for signature verification. + + This allows re-verification of signatures on cache hits, not just on fetch. + """ + body: bytes # Raw response body (JSON for TCB/QE, DER for CRLs) + issuer_chain_pem: Optional[str] = None # PEM-encoded issuer cert chain + + +def _ensure_cache_dir() -> bool: + """ + Ensure cache directory exists with secure permissions (0700). + + Returns: + True if directory exists/was created, False on failure + """ + try: + os.makedirs(_TDX_CACHE_DIR, mode=_CACHE_DIR_MODE, exist_ok=True) + # Tighten permissions if directory already existed with looser perms + os.chmod(_TDX_CACHE_DIR, _CACHE_DIR_MODE) + return True + except OSError: + return False + + +def _get_tcb_info_cache_path(fmspc: str) -> str: + """Get cache file path for TCB Info (keyed by FMSPC).""" + return os.path.join(_TDX_CACHE_DIR, f"tdx_tcb_info_{fmspc.lower()}.json") + + +def _get_qe_identity_cache_path() -> str: + """Get cache file path for QE Identity (global, not FMSPC-specific).""" + return os.path.join(_TDX_CACHE_DIR, "tdx_qe_identity.json") + + +def _read_cache(cache_path: str) -> Optional[CacheEntry]: + """ + Read cached collateral from disk. + + Returns: + CacheEntry with body and optional issuer chain, or None on failure + """ + if not os.path.isfile(cache_path): + return None + try: + with open(cache_path, "rb") as f: + data = json.loads(f.read().decode("utf-8")) + return CacheEntry( + body=base64.b64decode(data["body"]), + issuer_chain_pem=data.get("issuer_chain_pem"), + ) + except (OSError, json.JSONDecodeError, KeyError, ValueError): + return None + + +def _write_cache(cache_path: str, entry: CacheEntry) -> None: + """ + Write collateral to disk cache atomically with secure permissions. + + Uses write-to-temp + rename pattern to prevent partial writes. + Sets file permissions to 0600 (owner read/write only). + """ + if not _ensure_cache_dir(): + return + + data = { + "body": base64.b64encode(entry.body).decode("ascii"), + } + if entry.issuer_chain_pem is not None: + data["issuer_chain_pem"] = entry.issuer_chain_pem + + tmp_path = cache_path + ".tmp" + try: + # Write to temp file + fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, _CACHE_FILE_MODE) + try: + os.write(fd, json.dumps(data).encode("utf-8")) + finally: + os.close(fd) + # Atomic rename + os.replace(tmp_path, cache_path) + except OSError: + # Clean up temp file on failure + try: + os.unlink(tmp_path) + except OSError: + pass + + +def _is_fresh(next_update: Optional[datetime]) -> bool: + """Check if a collateral item is still fresh (now < next_update).""" + if next_update is None: + return False + return datetime.now(timezone.utc) < next_update + + +def _is_tcb_info_fresh(tcb_info: TdxTcbInfo) -> bool: + """Check if TCB Info is still fresh (not expired).""" + return _is_fresh(tcb_info.tcb_info.next_update) + + +def _is_qe_identity_fresh(qe_identity: QeIdentity) -> bool: + """Check if QE Identity is still fresh (not expired).""" + return _is_fresh(qe_identity.enclave_identity.next_update) + + +def _get_crl_cache_path(ca_type: str) -> str: + """Get cache file path for PCK CRL (keyed by CA type).""" + return os.path.join(_TDX_CACHE_DIR, f"tdx_pck_crl_{ca_type.lower()}.json") + + +def _get_root_crl_cache_path() -> str: + """Get cache file path for Intel SGX Root CA CRL.""" + return os.path.join(_TDX_CACHE_DIR, "intel_sgx_root_ca_crl.json") + + +def _is_crl_fresh(crl: x509.CertificateRevocationList) -> bool: + """Check if CRL is still fresh (not expired).""" + return _is_fresh(crl.next_update_utc) + + +# ============================================================================= +# Intel-defined field sizes (hex character counts) +# ============================================================================= + +# Minimum required tcbEvaluationDataNumber. +# This prevents accepting collateral issued before critical security updates. +# See: https://www.intel.com/content/www/us/en/developer/topic-technology/software-security-guidance/trusted-computing-base-recovery-attestation.html +# +# This value should ideally be set dynamically using: +# from tinfoil.attestation.collateral_tdx import calculate_min_tcb_evaluation_data_number +# min_num = calculate_min_tcb_evaluation_data_number() +# +# That function queries Intel PCS and returns the lowest tcbEvaluationDataNumber +# whose TCB recovery event date is within the last year. +# +# Current value 18 corresponds to TCB recovery event date 2024-11-12. +DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER = 18 + +TDX_MRSIGNER_SIZE = 48 # TDX module MRSIGNER: 48 bytes +QE_MRSIGNER_SIZE = 32 # QE enclave MRSIGNER: 32 bytes +QE_ATTRIBUTES_SIZE = 16 # QE enclave attributes: 16 bytes +MISCSELECT_SIZE = 4 # MISCSELECT field: 4 bytes +TCB_COMPONENT_COUNT = 16 # Number of TCB SVN components +ECDSA_P256_COMPONENT_SIZE = 32 # 32 bytes per R or S component +ECDSA_SIGNATURE_SIZE = 64 # R || S (32 + 32 bytes) + + +# ============================================================================= +# Parsing Functions +# ============================================================================= + +def _parse_hex_bytes(hex_str: str) -> bytes: + """Parse hex string to bytes.""" + return bytes.fromhex(hex_str) + + +def _parse_datetime(dt_str: str) -> datetime: + """Parse ISO datetime string to datetime object.""" + # Handle format: "2025-12-17T06:24:56Z" + return datetime.fromisoformat(dt_str.replace("Z", "+00:00")) + + +def _parse_tcb_component(data: dict) -> TcbComponent: + """Parse a TCB component from JSON.""" + return TcbComponent( + svn=data.get("svn", 0), + category=data.get("category", ""), + type=data.get("type", ""), + ) + + +def _parse_tcb(data: dict) -> Tcb: + """Parse TCB from JSON.""" + sgx_components = [ + _parse_tcb_component(c) for c in data.get("sgxtcbcomponents", []) + ] + tdx_components = [ + _parse_tcb_component(c) for c in data.get("tdxtcbcomponents", []) + ] + + # Pad to 16 components if needed + while len(sgx_components) < TCB_COMPONENT_COUNT: + sgx_components.append(TcbComponent(svn=0)) + while len(tdx_components) < TCB_COMPONENT_COUNT: + tdx_components.append(TcbComponent(svn=0)) + + # Capture isvsvn if present (used by QE Identity TCB levels) + isv_svn = data.get("isvsvn") + + return Tcb( + sgx_tcb_components=sgx_components[:TCB_COMPONENT_COUNT], + pce_svn=data.get("pcesvn", 0), + tdx_tcb_components=tdx_components[:TCB_COMPONENT_COUNT], + isv_svn=isv_svn, + ) + + +def _parse_tcb_level(data: dict) -> TcbLevel: + """Parse a TCB level from JSON.""" + tcb_data = data.get("tcb", {}) + + # Handle simple TCB (just isvsvn) vs full TCB + if "sgxtcbcomponents" in tcb_data: + tcb = _parse_tcb(tcb_data) + else: + # Simple TCB with just isvsvn (used by QE Identity) + # Capture the isvsvn value for QE TCB level matching + isv_svn = tcb_data.get("isvsvn") + tcb = Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=isv_svn, + ) + + return TcbLevel( + tcb=tcb, + tcb_date=data.get("tcbDate", ""), + tcb_status=TcbStatus(data.get("tcbStatus", "UpToDate")), + advisory_ids=data.get("advisoryIDs", []), + ) + + +def _parse_tdx_module(data: dict) -> TdxModule: + """Parse TDX module from JSON.""" + return TdxModule( + mrsigner=_parse_hex_bytes(data.get("mrsigner", "00" * TDX_MRSIGNER_SIZE)), + attributes=_parse_hex_bytes(data.get("attributes", "")), + attributes_mask=_parse_hex_bytes(data.get("attributesMask", "")), + ) + + +def _parse_tdx_module_identity(data: dict) -> TdxModuleIdentity: + """Parse TDX module identity from JSON.""" + return TdxModuleIdentity( + id=data.get("id", ""), + mrsigner=_parse_hex_bytes(data.get("mrsigner", "00" * TDX_MRSIGNER_SIZE)), + attributes=_parse_hex_bytes(data.get("attributes", "")), + attributes_mask=_parse_hex_bytes(data.get("attributesMask", "")), + tcb_levels=[_parse_tcb_level(l) for l in data.get("tcbLevels", [])], + ) + + +def _parse_tcb_info(data: dict) -> TcbInfo: + """Parse TCB Info from JSON.""" + tdx_module = None + if "tdxModule" in data: + tdx_module = _parse_tdx_module(data["tdxModule"]) + + return TcbInfo( + id=data.get("id", ""), + version=data.get("version", 0), + issue_date=_parse_datetime(data.get("issueDate", "1970-01-01T00:00:00Z")), + next_update=_parse_datetime(data.get("nextUpdate", "1970-01-01T00:00:00Z")), + fmspc=data.get("fmspc", ""), + pce_id=data.get("pceId", ""), + tcb_type=data.get("tcbType", 0), + tcb_evaluation_data_number=data.get("tcbEvaluationDataNumber", 0), + tdx_module=tdx_module, + tdx_module_identities=[ + _parse_tdx_module_identity(m) for m in data.get("tdxModuleIdentities", []) + ], + tcb_levels=[_parse_tcb_level(l) for l in data.get("tcbLevels", [])], + ) + + +def _parse_enclave_identity(data: dict) -> EnclaveIdentity: + """Parse Enclave Identity from JSON.""" + return EnclaveIdentity( + id=data.get("id", ""), + version=data.get("version", 0), + issue_date=_parse_datetime(data.get("issueDate", "1970-01-01T00:00:00Z")), + next_update=_parse_datetime(data.get("nextUpdate", "1970-01-01T00:00:00Z")), + tcb_evaluation_data_number=data.get("tcbEvaluationDataNumber", 0), + miscselect=_parse_hex_bytes(data.get("miscselect", "00" * MISCSELECT_SIZE)), + miscselect_mask=_parse_hex_bytes(data.get("miscselectMask", "00" * MISCSELECT_SIZE)), + attributes=_parse_hex_bytes(data.get("attributes", "00" * QE_ATTRIBUTES_SIZE)), + attributes_mask=_parse_hex_bytes(data.get("attributesMask", "00" * QE_ATTRIBUTES_SIZE)), + mrsigner=_parse_hex_bytes(data.get("mrsigner", "00" * QE_MRSIGNER_SIZE)), + isv_prod_id=data.get("isvprodid", 0), + tcb_levels=[_parse_tcb_level(l) for l in data.get("tcbLevels", [])], + ) + + +def parse_tcb_info_response(response_bytes: bytes) -> TdxTcbInfo: + """ + Parse TCB Info response from Intel PCS. + + Args: + response_bytes: Raw JSON response bytes + + Returns: + Parsed TdxTcbInfo + + Raises: + CollateralError: If parsing fails + """ + try: + data = json.loads(response_bytes) + except json.JSONDecodeError as e: + raise CollateralError(f"Failed to parse TCB Info JSON: {e}") + + tcb_info = _parse_tcb_info(data.get("tcbInfo", {})) + + # Validate required fields + if tcb_info.id != "TDX": + raise CollateralError(f"TCB Info ID must be 'TDX', got '{tcb_info.id}'") + if tcb_info.version != 3: + raise CollateralError(f"TCB Info version must be 3, got {tcb_info.version}") + + return TdxTcbInfo( + tcb_info=tcb_info, + signature=data.get("signature", ""), + ) + + +def parse_qe_identity_response(response_bytes: bytes) -> QeIdentity: + """ + Parse QE Identity response from Intel PCS. + + Args: + response_bytes: Raw JSON response bytes + + Returns: + Parsed QeIdentity + + Raises: + CollateralError: If parsing fails + """ + try: + data = json.loads(response_bytes) + except json.JSONDecodeError as e: + raise CollateralError(f"Failed to parse QE Identity JSON: {e}") + + enclave_identity = _parse_enclave_identity(data.get("enclaveIdentity", {})) + + # Validate required fields + if enclave_identity.id != "TD_QE": + raise CollateralError( + f"QE Identity ID must be 'TD_QE', got '{enclave_identity.id}'" + ) + if enclave_identity.version != 2: + raise CollateralError( + f"QE Identity version must be 2, got {enclave_identity.version}" + ) + + return QeIdentity( + enclave_identity=enclave_identity, + signature=data.get("signature", ""), + ) + + +# ============================================================================= +# Collateral Signature Verification +# ============================================================================= + +def _parse_issuer_chain_header(header_value: str) -> List[x509.Certificate]: + """ + Parse the issuer certificate chain from a PCS response header. + + Intel PCS returns the chain as URL-encoded concatenated PEM certificates + in the TCB-Info-Issuer-Chain or SGX-Enclave-Identity-Issuer-Chain header. + + Args: + header_value: URL-encoded PEM certificate chain + + Returns: + List of parsed certificates (signing cert first, root last) + + Raises: + CollateralError: If parsing fails + """ + # URL-decode the header value + pem_data = unquote(header_value).encode('utf-8') + + try: + certs = parse_pem_chain(pem_data) + except CertificateChainError as e: + raise CollateralError(f"Failed to parse issuer chain certificate: {e}") + + if len(certs) < 2: + raise CollateralError( + f"Issuer chain should contain at least 2 certificates, got {len(certs)}" + ) + + return certs + + +def _verify_collateral_signature( + json_bytes: bytes, + json_key: str, + signature_hex: str, + signing_cert: x509.Certificate, + data_name: str, +) -> None: + """ + Verify the signature over collateral JSON data. + + Intel PCS signs the raw JSON string of the inner object (tcbInfo or + enclaveIdentity), not the outer wrapper. The signature is ECDSA P-256 + over SHA256 of the raw JSON bytes. + + Args: + json_bytes: Raw response bytes (full JSON) + json_key: Key name to extract for signing ("tcbInfo" or "enclaveIdentity") + signature_hex: Hex-encoded signature from response + signing_cert: Signing certificate (leaf of issuer chain) + data_name: Human-readable name for error messages + + Raises: + CollateralError: If signature verification fails + """ + # Extract the raw JSON string for the signed object + # We need to find the exact bytes of the inner JSON object as it appears + # in the response (including whitespace), since the signature is over + # the exact byte representation. + # + # The format is: {"tcbInfo":{...},"signature":"..."} + # We need to extract the exact "{...}" for tcbInfo + try: + json_str = json_bytes.decode('utf-8') + except UnicodeDecodeError as e: + raise CollateralError(f"Failed to decode {data_name} JSON: {e}") + + # Find the start of the inner object. + # Only match at the top-level (must be preceded by '{' or ',' ignoring whitespace). + key_pattern = f'"{json_key}":' + key_pos = json_str.find(key_pattern) + if key_pos == -1: + raise CollateralError(f"{data_name} JSON does not contain '{json_key}' key") + + # Verify the match is at the top level (preceded by '{' or ',' outside strings) + prefix = json_str[:key_pos].rstrip() + if not prefix or prefix[-1] not in ('{', ','): + raise CollateralError( + f"{data_name} JSON key '{json_key}' found in unexpected position" + ) + + # Find the start of the object value (skip whitespace after colon) + obj_start = key_pos + len(key_pattern) + while obj_start < len(json_str) and json_str[obj_start] in ' \t\n\r': + obj_start += 1 + + if obj_start >= len(json_str) or json_str[obj_start] != '{': + raise CollateralError(f"{data_name} '{json_key}' is not an object") + + # Use the stdlib JSON decoder to find the extent of the nested object. + # raw_decode parses from the given index and returns the end position, + # correctly handling string escaping, nesting, and all JSON edge cases. + try: + decoder = json.JSONDecoder() + _, obj_end = decoder.raw_decode(json_str, obj_start) + except json.JSONDecodeError as e: + raise CollateralError(f"{data_name} JSON has malformed '{json_key}' object: {e}") from e + + # Extract the signed data (exact bytes as they appear in the response) + signed_json = json_str[obj_start:obj_end].encode('utf-8') + + # Parse signature (hex-encoded raw R||S format, 64 bytes = 128 hex chars) + try: + sig_bytes = bytes.fromhex(signature_hex) + except ValueError as e: + raise CollateralError(f"{data_name} signature is not valid hex: {e}") + + if len(sig_bytes) != ECDSA_SIGNATURE_SIZE: + raise CollateralError( + f"{data_name} signature is {len(sig_bytes)} bytes, expected {ECDSA_SIGNATURE_SIZE}" + ) + + # Convert R||S to DER format + r = int.from_bytes(sig_bytes[0:ECDSA_P256_COMPONENT_SIZE], byteorder='big') + s = int.from_bytes(sig_bytes[ECDSA_P256_COMPONENT_SIZE:ECDSA_SIGNATURE_SIZE], byteorder='big') + signature_der = encode_dss_signature(r, s) + + # Verify ECDSA signature + try: + public_key = signing_cert.public_key() + public_key.verify(signature_der, signed_json, ec.ECDSA(hashes.SHA256())) + except InvalidSignature: + raise CollateralError( + f"{data_name} signature verification failed: signature does not match content" + ) + + +def verify_tcb_info_signature( + response_bytes: bytes, + tcb_info: TdxTcbInfo, + issuer_chain: List[x509.Certificate], +) -> None: + """ + Verify TCB Info signature against the issuer certificate chain. + + Args: + response_bytes: Raw TCB Info response bytes + tcb_info: Parsed TCB Info + issuer_chain: Issuer certificate chain from response header + + Raises: + CollateralError: If verification fails + """ + # Verify the issuer chain + try: + verify_intel_chain(issuer_chain, "TCB Info issuer chain") + except CertificateChainError as e: + raise CollateralError(str(e)) + + # Verify signature over tcbInfo JSON + _verify_collateral_signature( + json_bytes=response_bytes, + json_key="tcbInfo", + signature_hex=tcb_info.signature, + signing_cert=issuer_chain[0], + data_name="TCB Info", + ) + + +def verify_qe_identity_signature( + response_bytes: bytes, + qe_identity: QeIdentity, + issuer_chain: List[x509.Certificate], +) -> None: + """ + Verify QE Identity signature against the issuer certificate chain. + + Args: + response_bytes: Raw QE Identity response bytes + qe_identity: Parsed QE Identity + issuer_chain: Issuer certificate chain from response header + + Raises: + CollateralError: If verification fails + """ + # Verify the issuer chain + try: + verify_intel_chain(issuer_chain, "QE Identity issuer chain") + except CertificateChainError as e: + raise CollateralError(str(e)) + + # Verify signature over enclaveIdentity JSON + _verify_collateral_signature( + json_bytes=response_bytes, + json_key="enclaveIdentity", + signature_hex=qe_identity.signature, + signing_cert=issuer_chain[0], + data_name="QE Identity", + ) + + +# ============================================================================= +# Collateral Fetching +# ============================================================================= + +def fetch_tcb_info( + fmspc: str, timeout: float = 30.0, session: HttpSession = None, +) -> Tuple[TdxTcbInfo, bytes, List[x509.Certificate]]: + """ + Fetch TCB Info from Intel PCS, with caching. + + Cached TCB Info is stored on disk and reused until the next_update + timestamp expires. Signature is re-verified on every cache hit. + + Args: + fmspc: FMSPC value from PCK certificate (6 bytes hex) + timeout: Request timeout in seconds + session: Optional requests.Session for connection reuse / testing + + Returns: + Tuple of (parsed TdxTcbInfo, raw response bytes, issuer certificate chain) + + Raises: + CollateralError: If fetching, parsing, or signature verification fails + """ + cache_path = _get_tcb_info_cache_path(fmspc) + + # Try cache first + cached_entry = _read_cache(cache_path) + if cached_entry is not None and cached_entry.issuer_chain_pem is not None: + try: + cached_tcb_info = parse_tcb_info_response(cached_entry.body) + if _is_tcb_info_fresh(cached_tcb_info): + # Re-verify signature on cache hit + issuer_chain = parse_pem_chain(cached_entry.issuer_chain_pem.encode("utf-8")) + verify_tcb_info_signature( + cached_entry.body, cached_tcb_info, issuer_chain + ) + return cached_tcb_info, cached_entry.body, issuer_chain + except (CollateralError, CertificateChainError, ValueError, KeyError, OSError): + pass + + # Cache miss or stale - fetch from Intel PCS + url = f"{INTEL_PCS_TDX_BASE_URL}/tcb?fmspc={fmspc}" + response = _http_get(url, timeout, session) + + raw_bytes = response.content + tcb_info = parse_tcb_info_response(raw_bytes) + + # Extract and verify issuer certificate chain from response header + issuer_chain_header = response.headers.get("TCB-Info-Issuer-Chain") + if not issuer_chain_header: + raise CollateralError( + "TCB Info response missing TCB-Info-Issuer-Chain header" + ) + + issuer_chain = _parse_issuer_chain_header(issuer_chain_header) + verify_tcb_info_signature(raw_bytes, tcb_info, issuer_chain) + + # Write to cache with issuer chain for re-verification + cache_entry = CacheEntry( + body=raw_bytes, + issuer_chain_pem=certs_to_pem(issuer_chain), + ) + _write_cache(cache_path, cache_entry) + + return tcb_info, raw_bytes, issuer_chain + + +def fetch_qe_identity( + timeout: float = 30.0, session: HttpSession = None, +) -> Tuple[QeIdentity, bytes, List[x509.Certificate]]: + """ + Fetch QE Identity from Intel PCS, with caching. + + Cached QE Identity is stored on disk and reused until the next_update + timestamp expires. Signature is re-verified on every cache hit. + + Note: QE Identity is global (not FMSPC-specific) so there's only one + cache file shared across all platforms. + + Args: + timeout: Request timeout in seconds + session: Optional requests.Session for connection reuse / testing + + Returns: + Tuple of (parsed QeIdentity, raw response bytes, issuer certificate chain) + + Raises: + CollateralError: If fetching, parsing, or signature verification fails + """ + cache_path = _get_qe_identity_cache_path() + + # Try cache first + cached_entry = _read_cache(cache_path) + if cached_entry is not None and cached_entry.issuer_chain_pem is not None: + try: + cached_qe_identity = parse_qe_identity_response(cached_entry.body) + if _is_qe_identity_fresh(cached_qe_identity): + # Re-verify signature on cache hit + issuer_chain = parse_pem_chain(cached_entry.issuer_chain_pem.encode("utf-8")) + verify_qe_identity_signature( + cached_entry.body, cached_qe_identity, issuer_chain + ) + return cached_qe_identity, cached_entry.body, issuer_chain + except (CollateralError, CertificateChainError, ValueError, KeyError, OSError): + pass + + # Cache miss or stale - fetch from Intel PCS + url = f"{INTEL_PCS_TDX_BASE_URL}/qe/identity" + response = _http_get(url, timeout, session) + + raw_bytes = response.content + qe_identity = parse_qe_identity_response(raw_bytes) + + # Extract and verify issuer certificate chain from response header + issuer_chain_header = response.headers.get("SGX-Enclave-Identity-Issuer-Chain") + if not issuer_chain_header: + raise CollateralError( + "QE Identity response missing SGX-Enclave-Identity-Issuer-Chain header" + ) + + issuer_chain = _parse_issuer_chain_header(issuer_chain_header) + verify_qe_identity_signature(raw_bytes, qe_identity, issuer_chain) + + # Write to cache with issuer chain for re-verification + cache_entry = CacheEntry( + body=raw_bytes, + issuer_chain_pem=certs_to_pem(issuer_chain), + ) + _write_cache(cache_path, cache_entry) + + return qe_identity, raw_bytes, issuer_chain + + +def _verify_crl_signature( + crl: x509.CertificateRevocationList, + issuer_chain: List[x509.Certificate], + ca_type: str, +) -> None: + """ + Verify the CRL signature against the issuer certificate chain. + + Args: + crl: Parsed CRL + issuer_chain: Issuer certificate chain from response header + ca_type: CA type for error messages + + Raises: + CollateralError: If verification fails + """ + # Verify the issuer chain first + try: + verify_intel_chain(issuer_chain, f"PCK CRL ({ca_type}) issuer chain") + except CertificateChainError as e: + raise CollateralError(str(e)) + + # Verify CRL signature using the signing cert (first in chain) + signing_cert = issuer_chain[0] + hash_algo = crl.signature_hash_algorithm + if hash_algo is None: + raise CollateralError(f"PCK CRL ({ca_type}) has no signature hash algorithm") + try: + signing_cert.public_key().verify( + crl.signature, + crl.tbs_certlist_bytes, + ec.ECDSA(hash_algo), + ) + except InvalidSignature: + raise CollateralError( + f"PCK CRL ({ca_type}) signature verification failed" + ) + + +def fetch_pck_crl( + ca_type: str, timeout: float = 30.0, session: HttpSession = None, +) -> PckCrl: + """ + Fetch PCK CRL from Intel PCS, with caching. + + The CRL is fetched from the SGX certification API (shared with TDX). + Cached CRL is stored on disk and reused until next_update expires. + Signature is re-verified on every cache hit. + + Args: + ca_type: CA type - "platform" or "processor" + timeout: Request timeout in seconds + + Returns: + Parsed PckCrl + + Raises: + CollateralError: If fetching, parsing, or signature verification fails + """ + if ca_type not in ("platform", "processor"): + raise CollateralError(f"Invalid CA type: {ca_type}. Must be 'platform' or 'processor'") + + cache_path = _get_crl_cache_path(ca_type) + + # Try cache first + cached_entry = _read_cache(cache_path) + if cached_entry is not None and cached_entry.issuer_chain_pem is not None: + try: + cached_crl = x509.load_pem_x509_crl(cached_entry.body) + if _is_crl_fresh(cached_crl): + # Re-verify signature on cache hit + issuer_chain = parse_pem_chain(cached_entry.issuer_chain_pem.encode("utf-8")) + _verify_crl_signature(cached_crl, issuer_chain, ca_type) + next_update = cached_crl.next_update_utc + if next_update is None: + raise CollateralError(f"PCK CRL ({ca_type}) is missing next_update field") + return PckCrl(crl=cached_crl, ca_type=ca_type, next_update=next_update) + except (CollateralError, CertificateChainError, ValueError, KeyError, OSError): + pass + + # Cache miss or stale - fetch from Intel PCS + # CRL endpoint is under the SGX API (not TDX-specific) + url = f"{INTEL_PCS_SGX_BASE_URL}/pckcrl?ca={ca_type}" + response = _http_get(url, timeout, session) + + raw_bytes = response.content + + # Parse the PEM-encoded CRL + try: + crl = x509.load_pem_x509_crl(raw_bytes) + except ValueError as e: + raise CollateralError(f"Failed to parse PCK CRL ({ca_type}): {e}") from e + + # Extract and verify issuer certificate chain from response header + issuer_chain_header = response.headers.get("SGX-PCK-CRL-Issuer-Chain") + if not issuer_chain_header: + raise CollateralError( + f"PCK CRL ({ca_type}) response missing SGX-PCK-CRL-Issuer-Chain header" + ) + + issuer_chain = _parse_issuer_chain_header(issuer_chain_header) + _verify_crl_signature(crl, issuer_chain, ca_type) + + # Write to cache with issuer chain for re-verification + cache_entry = CacheEntry( + body=raw_bytes, + issuer_chain_pem=certs_to_pem(issuer_chain), + ) + _write_cache(cache_path, cache_entry) + + next_update = crl.next_update_utc + if next_update is None: + raise CollateralError(f"PCK CRL ({ca_type}) is missing next_update field") + + return PckCrl(crl=crl, ca_type=ca_type, next_update=next_update) + + +# Intel SGX Root CA CRL URL (from the certificate's CRL Distribution Point) +INTEL_SGX_ROOT_CA_CRL_URL = "https://certificates.trustedservices.intel.com/IntelSGXRootCA.der" + + +def _verify_root_crl_signature(crl: x509.CertificateRevocationList) -> None: + """ + Verify the Root CA CRL signature against Intel SGX Root CA. + + Args: + crl: Parsed CRL + + Raises: + CollateralError: If signature verification fails + """ + intel_root = get_intel_root_ca() + hash_algo = crl.signature_hash_algorithm + if hash_algo is None: + raise CollateralError("Intel SGX Root CA CRL has no signature hash algorithm") + try: + intel_root.public_key().verify( + crl.signature, + crl.tbs_certlist_bytes, + ec.ECDSA(hash_algo), + ) + except InvalidSignature: + raise CollateralError("Intel SGX Root CA CRL signature verification failed") + + +def fetch_root_ca_crl(timeout: float = 30.0, session: HttpSession = None) -> RootCrl: + """ + Fetch Intel SGX Root CA CRL, with caching. + + This CRL lists revoked intermediate CA certificates (Platform CA, Processor CA). + Used to verify that the PCK certificate chain's intermediate CA is not revoked. + Signature is re-verified on every cache hit. + + Args: + timeout: Request timeout in seconds + + Returns: + Parsed RootCrl + + Raises: + CollateralError: If fetching, parsing, or signature verification fails + """ + cache_path = _get_root_crl_cache_path() + + # Try cache first + cached_entry = _read_cache(cache_path) + if cached_entry is not None: + try: + cached_crl = x509.load_der_x509_crl(cached_entry.body) + if _is_crl_fresh(cached_crl): + # Re-verify signature on cache hit + _verify_root_crl_signature(cached_crl) + next_update = cached_crl.next_update_utc + if next_update is None: + raise CollateralError("Intel SGX Root CA CRL is missing next_update field") + return RootCrl(crl=cached_crl, next_update=next_update) + except (CollateralError, CertificateChainError, ValueError, KeyError, OSError): + pass + + # Cache miss or stale - fetch from Intel + response = _http_get(INTEL_SGX_ROOT_CA_CRL_URL, timeout, session) + + raw_bytes = response.content + + # Parse the DER-encoded CRL + try: + crl = x509.load_der_x509_crl(raw_bytes) + except Exception as e: + raise CollateralError(f"Failed to parse Intel SGX Root CA CRL: {e}") + + # Verify CRL is signed by the Intel SGX Root CA + _verify_root_crl_signature(crl) + + # Write to cache (no issuer chain needed - verified against embedded root) + cache_entry = CacheEntry(body=raw_bytes) + _write_cache(cache_path, cache_entry) + + next_update = crl.next_update_utc + if next_update is None: + raise CollateralError("Intel SGX Root CA CRL is missing next_update field") + + return RootCrl(crl=crl, next_update=next_update) + + +def _determine_pck_ca_type(pck_cert: x509.Certificate) -> str: + """ + Determine which CA issued the PCK certificate. + + The PCK certificate issuer CN indicates the CA type: + - "Intel SGX PCK Platform CA" -> "platform" + - "Intel SGX PCK Processor CA" -> "processor" + + Args: + pck_cert: PCK certificate from the quote + + Returns: + CA type string ("platform" or "processor") + + Raises: + CollateralError: If CA type cannot be determined + """ + try: + issuer = pck_cert.issuer + for attr in issuer: + if attr.oid == x509.oid.NameOID.COMMON_NAME: + cn = attr.value + if "Platform" in cn: + return "platform" + elif "Processor" in cn: + return "processor" + raise CollateralError( + f"Could not determine PCK CA type from issuer: {issuer}" + ) + except Exception as e: + raise CollateralError(f"Failed to determine PCK CA type: {e}") + + +def fetch_collateral( + pck_extensions: PckExtensions, + pck_cert: x509.Certificate, + timeout: float = 30.0, + session: HttpSession = None, +) -> TdxCollateral: + """ + Fetch all required collateral from Intel PCS. + + Args: + pck_extensions: PCK certificate extensions containing FMSPC + pck_cert: PCK certificate (needed for CRL fetching) + timeout: Request timeout in seconds + session: Optional requests.Session for connection reuse / testing + + Returns: + TdxCollateral containing all fetched data + + Raises: + CollateralError: If fetching fails + """ + tcb_info, tcb_info_raw, tcb_info_issuer_chain = fetch_tcb_info(pck_extensions.fmspc, timeout, session) + qe_identity, qe_identity_raw, qe_identity_issuer_chain = fetch_qe_identity(timeout, session) + + # Fetch CRLs for revocation checking + ca_type = _determine_pck_ca_type(pck_cert) + pck_crl = fetch_pck_crl(ca_type, timeout, session) + root_crl = fetch_root_ca_crl(timeout, session) + + return TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=tcb_info_raw, + qe_identity_raw=qe_identity_raw, + pck_crl=pck_crl, + root_crl=root_crl, + tcb_info_issuer_chain=tcb_info_issuer_chain, + qe_identity_issuer_chain=qe_identity_issuer_chain, + ) + + +# ============================================================================= +# TCB Evaluation Data Number Calculation +# ============================================================================= + +@dataclass +class TcbEvalNumber: + """A single TCB evaluation data number with its dates.""" + tcb_evaluation_data_number: int + tcb_recovery_event_date: datetime + tcb_date: datetime + + +@dataclass +class TcbEvaluationDataNumbers: + """Response from the tcbevaluationdatanumbers endpoint.""" + id: str # "TDX" + version: int + issue_date: datetime + next_update: datetime + tcb_eval_numbers: List[TcbEvalNumber] + signature: str + + +def _parse_tcb_eval_numbers_response(response_bytes: bytes) -> TcbEvaluationDataNumbers: + """ + Parse the tcbevaluationdatanumbers response from Intel PCS. + + Args: + response_bytes: Raw JSON response bytes + + Returns: + Parsed TcbEvaluationDataNumbers + + Raises: + CollateralError: If parsing fails + """ + try: + data = json.loads(response_bytes) + except json.JSONDecodeError as e: + raise CollateralError(f"Failed to parse TCB evaluation data numbers JSON: {e}") + + inner = data.get("tcbEvaluationDataNumbers", {}) + + tcb_eval_numbers = [] + for item in inner.get("tcbEvalNumbers", []): + tcb_eval_numbers.append(TcbEvalNumber( + tcb_evaluation_data_number=item.get("tcbEvaluationDataNumber", 0), + tcb_recovery_event_date=_parse_datetime(item.get("tcbRecoveryEventDate", "1970-01-01T00:00:00Z")), + tcb_date=_parse_datetime(item.get("tcbDate", "1970-01-01T00:00:00Z")), + )) + + return TcbEvaluationDataNumbers( + id=inner.get("id", ""), + version=inner.get("version", 0), + issue_date=_parse_datetime(inner.get("issueDate", "1970-01-01T00:00:00Z")), + next_update=_parse_datetime(inner.get("nextUpdate", "1970-01-01T00:00:00Z")), + tcb_eval_numbers=tcb_eval_numbers, + signature=data.get("signature", ""), + ) + + +def verify_tcb_eval_numbers_signature( + response_bytes: bytes, + data: TcbEvaluationDataNumbers, + issuer_chain: List[x509.Certificate], +) -> None: + """ + Verify TCB evaluation data numbers signature against the issuer certificate chain. + + Args: + response_bytes: Raw response bytes + data: Parsed TcbEvaluationDataNumbers + issuer_chain: Issuer certificate chain from response header + + Raises: + CollateralError: If verification fails + """ + try: + verify_intel_chain(issuer_chain, "TCB eval numbers issuer chain") + except CertificateChainError as e: + raise CollateralError(str(e)) + + _verify_collateral_signature( + json_bytes=response_bytes, + json_key="tcbEvaluationDataNumbers", + signature_hex=data.signature, + signing_cert=issuer_chain[0], + data_name="TCB evaluation data numbers", + ) + + +def fetch_tcb_evaluation_data_numbers( + timeout: float = 30.0, session: HttpSession = None, +) -> TcbEvaluationDataNumbers: + """ + Fetch TCB evaluation data numbers from Intel PCS. + + This endpoint returns all TCB evaluation data numbers with their + TCB recovery event dates, which can be used to determine the minimum + acceptable tcbEvaluationDataNumber based on age. + + Args: + timeout: Request timeout in seconds + session: Optional requests.Session for connection reuse / testing + + Returns: + Parsed TcbEvaluationDataNumbers + + Raises: + CollateralError: If fetching or parsing fails + """ + url = f"{INTEL_PCS_TDX_BASE_URL}/tcbevaluationdatanumbers" + response = _http_get(url, timeout, session) + + raw_bytes = response.content + data = _parse_tcb_eval_numbers_response(raw_bytes) + + issuer_chain_header = response.headers.get("TCB-Evaluation-Data-Numbers-Issuer-Chain") + if not issuer_chain_header: + raise CollateralError( + "TCB evaluation data numbers response missing " + "TCB-Evaluation-Data-Numbers-Issuer-Chain header" + ) + + issuer_chain = _parse_issuer_chain_header(issuer_chain_header) + verify_tcb_eval_numbers_signature(raw_bytes, data, issuer_chain) + + return data + + +def calculate_min_tcb_evaluation_data_number( + max_age_days: int = 365, + timeout: float = 30.0, + session: HttpSession = None, +) -> int: + """ + Calculate the minimum acceptable tcbEvaluationDataNumber based on age. + + This function queries Intel PCS to find the lowest tcbEvaluationDataNumber + whose TCB recovery event date is within the specified maximum age. This + ensures that collateral from older TCB recovery events is rejected. + + The response signature is verified against the Intel SGX root certificate + when the issuer chain header is present. + + Args: + max_age_days: Maximum age in days (default: 365 = 1 year) + timeout: Request timeout in seconds + session: Optional requests.Session for connection reuse / testing + + Returns: + Minimum acceptable tcbEvaluationDataNumber + + Raises: + CollateralError: If calculation fails or no valid number found + + Example: + >>> min_num = calculate_min_tcb_evaluation_data_number() + >>> print(f"Minimum acceptable tcbEvaluationDataNumber: {min_num}") + """ + cutoff_date = datetime.now(timezone.utc) - timedelta(days=max_age_days) + + data = fetch_tcb_evaluation_data_numbers(timeout, session=session) + + if not data.tcb_eval_numbers: + raise CollateralError("No TCB evaluation data numbers found in Intel PCS response") + + # Sort by number ascending to find the lowest acceptable + sorted_numbers = sorted(data.tcb_eval_numbers, key=lambda x: x.tcb_evaluation_data_number) + + # Find the lowest number whose tcbRecoveryEventDate is >= cutoff_date + for item in sorted_numbers: + if item.tcb_recovery_event_date >= cutoff_date: + return item.tcb_evaluation_data_number + + # Fail closed if all numbers are too old + highest = sorted_numbers[-1] + raise CollateralError( + f"All TCB evaluation data numbers are older than {max_age_days} days. " + f"Most recent is {highest.tcb_evaluation_data_number} from " + f"{highest.tcb_recovery_event_date}." + ) + + +# ============================================================================= +# TCB Level Matching and Validation +# ============================================================================= + +def is_cpu_svn_higher_or_equal( + pck_cert_cpu_svn: bytes, + sgx_tcb_components: List[TcbComponent], +) -> bool: + """ + Check if PCK certificate CPU SVN is >= TCB level SGX components. + + Args: + pck_cert_cpu_svn: CPU SVN from PCK certificate (16 bytes) + sgx_tcb_components: SGX TCB components from TCB level + + Returns: + True if all components are >= + """ + if len(pck_cert_cpu_svn) != len(sgx_tcb_components): + return False + + for i, component in enumerate(sgx_tcb_components): + if pck_cert_cpu_svn[i] < component.svn: + return False + + return True + + +def is_tdx_tcb_svn_higher_or_equal( + tee_tcb_svn: bytes, + tdx_tcb_components: List[TcbComponent], +) -> bool: + """ + Check if TEE TCB SVN is >= TCB level TDX components. + + Args: + tee_tcb_svn: TEE TCB SVN from quote body (16 bytes) + tdx_tcb_components: TDX TCB components from TCB level + + Returns: + True if all relevant components are >= + """ + if len(tee_tcb_svn) != len(tdx_tcb_components): + return False + + # If teeTcbSvn[1] > 0, skip first 2 bytes (module-specific behavior) + start = 0 + if len(tee_tcb_svn) > 1 and tee_tcb_svn[1] > 0: + start = 2 + + for i in range(start, len(tee_tcb_svn)): + if tee_tcb_svn[i] < tdx_tcb_components[i].svn: + return False + + return True + + +def get_matching_tcb_level( + tcb_levels: List[TcbLevel], + tee_tcb_svn: bytes, + pck_cert_pce_svn: int, + pck_cert_cpu_svn: bytes, +) -> Optional[TcbLevel]: + """ + Find the matching TCB level for the given SVN values. + + TCB levels are ordered from newest to oldest. Returns the first + level where all three checks pass. + + Args: + tcb_levels: List of TCB levels from TCB Info + tee_tcb_svn: TEE TCB SVN from quote body + pck_cert_pce_svn: PCE SVN from PCK certificate + pck_cert_cpu_svn: CPU SVN from PCK certificate + + Returns: + Matching TcbLevel or None if no match found + """ + for level in tcb_levels: + if (is_cpu_svn_higher_or_equal(pck_cert_cpu_svn, level.tcb.sgx_tcb_components) and + pck_cert_pce_svn >= level.tcb.pce_svn and + is_tdx_tcb_svn_higher_or_equal(tee_tcb_svn, level.tcb.tdx_tcb_components)): + return level + + return None + + +def get_matching_qe_tcb_level( + tcb_levels: List[TcbLevel], + isv_svn: int, +) -> Optional[TcbLevel]: + """ + Find the matching TCB level for the QE ISV SVN. + + TCB levels are ordered from newest to oldest. Returns the first + level where the report's ISV SVN >= the level's ISV SVN. + + Args: + tcb_levels: List of TCB levels from QE Identity + isv_svn: ISV SVN from QE report + + Returns: + Matching TcbLevel or None if no match found + """ + for level in tcb_levels: + # QE TCB levels use isvsvn in the tcb structure + # The level matches if report's ISV SVN >= level's ISV SVN + if level.tcb.isv_svn is not None and isv_svn >= level.tcb.isv_svn: + return level + + return None + + +def validate_tcb_status( + tcb_info: TcbInfo, + tee_tcb_svn: bytes, + pck_extensions: PckExtensions, + strict_mode: bool = False, +) -> TcbLevel: + """ + Validate TCB status against Intel's published levels. + + Performs the following checks: + 1. Cross-validates FMSPC and PCE_ID between PCK cert and TCB Info + 2. Finds matching TCB level for platform SVN values + 3. Checks TCB status (REVOKED and OUT_OF_DATE always rejected) + + Args: + tcb_info: Parsed TCB Info from Intel PCS + tee_tcb_svn: TEE TCB SVN from quote body + pck_extensions: PCK certificate extensions + strict_mode: If True, reject SW_HARDENING_NEEDED and CONFIGURATION_NEEDED + + Returns: + Matching TcbLevel + + Raises: + CollateralError: If TCB status is not acceptable + """ + # Cross-check FMSPC between PCK certificate and TCB Info + pck_fmspc = pck_extensions.fmspc.lower() + tcb_fmspc = tcb_info.fmspc.lower() + if pck_fmspc != tcb_fmspc: + raise CollateralError( + f"FMSPC mismatch: PCK certificate has {pck_fmspc}, " + f"TCB Info has {tcb_fmspc}" + ) + + # Cross-check PCE_ID between PCK certificate and TCB Info + pck_pceid = pck_extensions.pceid.lower() + tcb_pceid = tcb_info.pce_id.lower() + if pck_pceid != tcb_pceid: + raise CollateralError( + f"PCE_ID mismatch: PCK certificate has {pck_pceid}, " + f"TCB Info has {tcb_pceid}" + ) + + matching_level = get_matching_tcb_level( + tcb_levels=tcb_info.tcb_levels, + tee_tcb_svn=tee_tcb_svn, + pck_cert_pce_svn=pck_extensions.tcb.pce_svn, + pck_cert_cpu_svn=pck_extensions.tcb.cpu_svn, + ) + + if matching_level is None: + raise CollateralError( + "No matching TCB level found for the quote's SVN values" + ) + + # Check status - always reject REVOKED and OUT_OF_DATE + if matching_level.tcb_status == TcbStatus.REVOKED: + raise CollateralError("TCB status is REVOKED - platform is not trusted") + + if matching_level.tcb_status == TcbStatus.OUT_OF_DATE: + raise CollateralError("TCB status is OUT_OF_DATE - platform needs update") + + if matching_level.tcb_status == TcbStatus.OUT_OF_DATE_CONFIGURATION_NEEDED: + raise CollateralError( + "TCB status is OUT_OF_DATE_CONFIGURATION_NEEDED - platform needs update" + ) + + # In strict mode, also reject statuses that indicate security advisories + if strict_mode: + if matching_level.tcb_status == TcbStatus.SW_HARDENING_NEEDED: + advisories = ", ".join(matching_level.advisory_ids) or "none listed" + raise CollateralError( + f"TCB status is SW_HARDENING_NEEDED (strict mode). " + f"Advisories: {advisories}" + ) + + if matching_level.tcb_status == TcbStatus.CONFIGURATION_NEEDED: + advisories = ", ".join(matching_level.advisory_ids) or "none listed" + raise CollateralError( + f"TCB status is CONFIGURATION_NEEDED (strict mode). " + f"Advisories: {advisories}" + ) + + if matching_level.tcb_status == TcbStatus.CONFIGURATION_AND_SW_HARDENING_NEEDED: + advisories = ", ".join(matching_level.advisory_ids) or "none listed" + raise CollateralError( + f"TCB status is CONFIGURATION_AND_SW_HARDENING_NEEDED (strict mode). " + f"Advisories: {advisories}" + ) + + return matching_level + + +def get_tdx_module_identity( + tcb_info: TcbInfo, + tee_tcb_svn: bytes, +) -> TdxModuleIdentity: + """ + Find the matching TDX module identity based on TEE_TCB_SVN. + + The module ID is derived from TEE_TCB_SVN[1] (major version): + - TEE_TCB_SVN[1] = 0x03 -> Module ID = "TDX_03" + - TEE_TCB_SVN[1] = 0x01 -> Module ID = "TDX_01" + + Args: + tcb_info: Parsed TCB Info from Intel PCS + tee_tcb_svn: TEE TCB SVN from quote body (16 bytes) + + Returns: + Matching TdxModuleIdentity + + Raises: + CollateralError: If no module identities in TCB Info, tee_tcb_svn is + malformed, or module version is unknown + """ + if not tcb_info.tdx_module_identities: + raise CollateralError("TCB Info has no TDX module identities") + + if len(tee_tcb_svn) < 2: + raise CollateralError( + f"TEE_TCB_SVN is too short ({len(tee_tcb_svn)} bytes, expected >= 2)" + ) + + # Extract major version and form module ID + # TEE_TCB_SVN[0] = minor SVN, TEE_TCB_SVN[1] = major SVN + major_version = tee_tcb_svn[1] + module_id = f"TDX_{major_version:02d}" + + # Find matching module identity + for module_identity in tcb_info.tdx_module_identities: + if module_identity.id == module_id: + return module_identity + + known_ids = [m.id for m in tcb_info.tdx_module_identities] + raise CollateralError( + f"Unknown TDX module version {module_id}. Known module identities: {known_ids}" + ) + + +def validate_tdx_module_identity( + tcb_info: TcbInfo, + tee_tcb_svn: bytes, + mr_signer_seam: bytes, + seam_attributes: bytes, +) -> None: + """ + Validate TDX module identity against Intel's published identities. + + This validates: + - MR_SIGNER_SEAM matches the expected module signer + - SEAM_ATTRIBUTES match under the attribute mask + - Module-specific TCB level matches the minor version (TEE_TCB_SVN[0]) + + Args: + tcb_info: Parsed TCB Info from Intel PCS + tee_tcb_svn: TEE TCB SVN from quote body (16 bytes) + mr_signer_seam: MR_SIGNER_SEAM from quote body (48 bytes) + seam_attributes: SEAM_ATTRIBUTES from quote body (8 bytes) + + Raises: + CollateralError: If module identity validation fails + """ + module_identity = get_tdx_module_identity(tcb_info, tee_tcb_svn) + + # Verify MR_SIGNER_SEAM matches + if mr_signer_seam != module_identity.mrsigner: + raise CollateralError( + f"TDX module MR_SIGNER_SEAM does not match expected value. " + f"Got {mr_signer_seam.hex()}, expected {module_identity.mrsigner.hex()}" + ) + + # Verify SEAM_ATTRIBUTES match under mask + # Pad seam_attributes to match mask length if needed + mask = module_identity.attributes_mask + attrs = seam_attributes + + # Handle length mismatch by padding with zeros + if len(attrs) < len(mask): + attrs = attrs + b'\x00' * (len(mask) - len(attrs)) + if len(mask) < len(attrs): + mask = mask + b'\x00' * (len(attrs) - len(mask)) + + report_attrs_masked = bytes(a & b for a, b in zip(attrs, mask)) + expected_attrs_masked = bytes( + a & b for a, b in zip(module_identity.attributes, mask) + ) + if report_attrs_masked != expected_attrs_masked: + raise CollateralError( + f"TDX module SEAM_ATTRIBUTES do not match expected value under mask. " + f"Got {seam_attributes.hex()}, expected {module_identity.attributes.hex()} " + f"(mask {module_identity.attributes_mask.hex()})" + ) + + # Find matching module-specific TCB level + # TEE_TCB_SVN[0] = minor SVN, TEE_TCB_SVN[1] = major SVN + # The minor version is used for module TCB matching + minor_version = tee_tcb_svn[0] + + for level in module_identity.tcb_levels: + # Module TCB levels should have isv_svn (minor version) set + if level.tcb.isv_svn is not None and minor_version >= level.tcb.isv_svn: + # Check status - reject insecure statuses consistently with + # validate_tcb_status and validate_qe_identity + if level.tcb_status == TcbStatus.REVOKED: + raise CollateralError( + f"TDX module TCB status is REVOKED for version " + f"{tee_tcb_svn[1]}.{minor_version}" + ) + if level.tcb_status == TcbStatus.OUT_OF_DATE: + raise CollateralError( + f"TDX module TCB status is OUT_OF_DATE for version " + f"{tee_tcb_svn[1]}.{minor_version}" + ) + if level.tcb_status == TcbStatus.OUT_OF_DATE_CONFIGURATION_NEEDED: + raise CollateralError( + f"TDX module TCB status is OUT_OF_DATE_CONFIGURATION_NEEDED " + f"for version {tee_tcb_svn[1]}.{minor_version}" + ) + return + + # No matching TCB level found - this is an error (matches Go behavior) + raise CollateralError( + f"Could not find a TDX Module Identity TCB Level matching " + f"the TDX Module's ISVSVN ({minor_version})" + ) + + +def validate_qe_identity( + qe_identity: EnclaveIdentity, + qe_report_isv_svn: int, + qe_report_mrsigner: bytes, + qe_report_miscselect: bytes, + qe_report_attributes: bytes, + qe_report_isvprodid: int, +) -> TcbLevel: + """ + Validate QE identity against Intel's published identity. + + This performs comprehensive validation of the Quoting Enclave: + - MRSIGNER must match exactly + - MISCSELECT must match under mask + - Attributes must match under mask + - ISV ProdID must match exactly + - ISV SVN must meet minimum threshold for a TCB level + + Args: + qe_identity: Parsed QE Identity from Intel PCS + qe_report_isv_svn: ISV SVN from QE report + qe_report_mrsigner: MRSIGNER from QE report (32 bytes) + qe_report_miscselect: MISCSELECT from QE report (4 bytes) + qe_report_attributes: Attributes from QE report (16 bytes) + qe_report_isvprodid: ISV ProdID from QE report + + Returns: + Matching TcbLevel + + Raises: + CollateralError: If QE identity validation fails + """ + # Verify MRSIGNER matches exactly + if qe_report_mrsigner != qe_identity.mrsigner: + raise CollateralError( + f"QE report MRSIGNER does not match expected value. " + f"Got {qe_report_mrsigner.hex()}, expected {qe_identity.mrsigner.hex()}" + ) + + # Verify MISCSELECT matches under mask + # (qe_report_miscselect & mask) == (qe_identity.miscselect & mask) + report_miscselect_masked = bytes( + a & b for a, b in zip(qe_report_miscselect, qe_identity.miscselect_mask) + ) + expected_miscselect_masked = bytes( + a & b for a, b in zip(qe_identity.miscselect, qe_identity.miscselect_mask) + ) + if report_miscselect_masked != expected_miscselect_masked: + raise CollateralError( + f"QE report MISCSELECT does not match expected value under mask. " + f"Got {qe_report_miscselect.hex()}, expected {qe_identity.miscselect.hex()} " + f"(mask {qe_identity.miscselect_mask.hex()})" + ) + + # Verify Attributes match under mask + # (qe_report_attributes & mask) == (qe_identity.attributes & mask) + report_attributes_masked = bytes( + a & b for a, b in zip(qe_report_attributes, qe_identity.attributes_mask) + ) + expected_attributes_masked = bytes( + a & b for a, b in zip(qe_identity.attributes, qe_identity.attributes_mask) + ) + if report_attributes_masked != expected_attributes_masked: + raise CollateralError( + f"QE report Attributes do not match expected value under mask. " + f"Got {qe_report_attributes.hex()}, expected {qe_identity.attributes.hex()} " + f"(mask {qe_identity.attributes_mask.hex()})" + ) + + # Verify ISV ProdID matches exactly + if qe_report_isvprodid != qe_identity.isv_prod_id: + raise CollateralError( + f"QE report ISV ProdID does not match expected value. " + f"Got {qe_report_isvprodid}, expected {qe_identity.isv_prod_id}" + ) + + # Find matching TCB level + matching_level = get_matching_qe_tcb_level( + tcb_levels=qe_identity.tcb_levels, + isv_svn=qe_report_isv_svn, + ) + + if matching_level is None: + raise CollateralError( + f"No matching QE TCB level found for ISV SVN {qe_report_isv_svn}" + ) + + # Check status + if matching_level.tcb_status == TcbStatus.REVOKED: + raise CollateralError("QE TCB status is REVOKED") + + if matching_level.tcb_status == TcbStatus.OUT_OF_DATE: + raise CollateralError("QE TCB status is OUT_OF_DATE") + + if matching_level.tcb_status == TcbStatus.OUT_OF_DATE_CONFIGURATION_NEEDED: + raise CollateralError("QE TCB status is OUT_OF_DATE_CONFIGURATION_NEEDED") + + return matching_level + + +def check_collateral_freshness( + collateral: TdxCollateral, + min_tcb_evaluation_data_number: int, +) -> None: + """ + Check that collateral is not expired and meets freshness requirements. + + This performs the following checks: + 1. TCB Info has not expired (now < next_update) + 2. QE Identity has not expired (now < next_update) + 3. Both TCB Info and QE Identity must have tcbEvaluationDataNumber >= threshold + + The tcbEvaluationDataNumber is a monotonically increasing number that + Intel updates when new TCB recovery (TCB-R) events occur. The minimum + threshold ensures we don't accept collateral issued before critical + security updates. + + See: https://www.intel.com/content/www/us/en/developer/topic-technology/software-security-guidance/trusted-computing-base-recovery-attestation.html + + Args: + collateral: TDX collateral to check + min_tcb_evaluation_data_number: Minimum tcbEvaluationDataNumber threshold. + Collateral with a lower number is rejected. + + Raises: + CollateralError: If collateral is expired or too old + """ + now = datetime.now(timezone.utc) + + tcb_next_update = collateral.tcb_info.tcb_info.next_update + if now > tcb_next_update: + raise CollateralError( + f"TCB Info has expired (next update was {tcb_next_update})" + ) + + qe_next_update = collateral.qe_identity.enclave_identity.next_update + if now > qe_next_update: + raise CollateralError( + f"QE Identity has expired (next update was {qe_next_update})" + ) + + # Check tcbEvaluationDataNumber threshold + tcb_eval_num = collateral.tcb_info.tcb_info.tcb_evaluation_data_number + if tcb_eval_num < min_tcb_evaluation_data_number: + raise CollateralError( + f"TCB Info tcbEvaluationDataNumber ({tcb_eval_num}) is below " + f"minimum required ({min_tcb_evaluation_data_number}). " + f"Collateral may be outdated." + ) + + qe_eval_num = collateral.qe_identity.enclave_identity.tcb_evaluation_data_number + if qe_eval_num < min_tcb_evaluation_data_number: + raise CollateralError( + f"QE Identity tcbEvaluationDataNumber ({qe_eval_num}) is below " + f"minimum required ({min_tcb_evaluation_data_number}). " + f"Collateral may be outdated." + ) + + +def validate_certificate_revocation( + collateral: TdxCollateral, + pck_cert: x509.Certificate, + intermediate_cert: Optional[x509.Certificate] = None, +) -> None: + """ + Validate that certificates in the attestation chain have not been revoked. + + Per Intel DCAP spec Section 2.3, checks: + 1. PCK (leaf) certificate against the PCK CRL from Intel PCS + 2. Intermediate CA certificate against the Intel SGX Root CA CRL + 3. TCB Info signing certificate against the Intel SGX Root CA CRL + 4. QE Identity signing certificate against the Intel SGX Root CA CRL + + Args: + collateral: TDX collateral containing CRLs and issuer chains + pck_cert: PCK certificate to check + intermediate_cert: Intermediate CA certificate to check (optional) + + Raises: + CollateralError: If any certificate is revoked or CRL check fails + + TODO: Use cryptography.x509.verification integrated CRL checking when available. + As of cryptography 46.0.3, PolicyBuilder doesn't support CRL stores. + Track: https://github.com/pyca/cryptography/issues + """ + if collateral.pck_crl is None: + raise CollateralError( + "Cannot check certificate revocation: PCK CRL not available in collateral" + ) + + # --- Check PCK (leaf) certificate against PCK CRL --- + pck_crl = collateral.pck_crl.crl + + # Check PCK CRL freshness + now = datetime.now(timezone.utc) + pck_crl_next_update = pck_crl.next_update_utc + if pck_crl_next_update is not None and now > pck_crl_next_update: + raise CollateralError( + f"PCK CRL has expired (next update was {pck_crl_next_update})" + ) + + # Check if PCK certificate is revoked + pck_serial = pck_cert.serial_number + revoked_pck = pck_crl.get_revoked_certificate_by_serial_number(pck_serial) + + if revoked_pck is not None: + revocation_date = revoked_pck.revocation_date_utc + raise CollateralError( + f"PCK certificate has been revoked. " + f"Serial: {pck_serial:x}, Revocation date: {revocation_date}" + ) + + # --- Check certificates against Root CA CRL --- + # The Root CRL lists revoked intermediate CAs (Platform CA, Processor CA, + # TCB Signing CA, etc.) issued by Intel SGX Root CA. + if collateral.root_crl is not None: + root_crl = collateral.root_crl.crl + + # Check Root CRL freshness (once, shared across all checks below) + root_crl_next_update = root_crl.next_update_utc + if root_crl_next_update is not None and now > root_crl_next_update: + raise CollateralError( + f"Intel SGX Root CA CRL has expired (next update was {root_crl_next_update})" + ) + + # Check intermediate CA certificate (from PCK cert chain) + if intermediate_cert is not None: + _check_cert_against_crl( + root_crl, intermediate_cert, "Intermediate CA" + ) + + # Check TCB Info signing certificate (per DCAP spec Section 2.3: + # "Check if verification collaterals are on the CRL") + if collateral.tcb_info_issuer_chain: + _check_cert_against_crl( + root_crl, collateral.tcb_info_issuer_chain[0], "TCB Info signing" + ) + + # Check QE Identity signing certificate + if collateral.qe_identity_issuer_chain: + _check_cert_against_crl( + root_crl, collateral.qe_identity_issuer_chain[0], "QE Identity signing" + ) + + elif intermediate_cert is not None or collateral.tcb_info_issuer_chain or collateral.qe_identity_issuer_chain: + raise CollateralError( + "Cannot check certificate revocation: Root CRL not available in collateral" + ) + + +def _check_cert_against_crl( + crl: x509.CertificateRevocationList, + cert: x509.Certificate, + cert_name: str, +) -> None: + """ + Check if a certificate has been revoked according to a CRL. + + Args: + crl: Certificate Revocation List to check against + cert: Certificate to check + cert_name: Human-readable name for error messages + + Raises: + CollateralError: If the certificate has been revoked + """ + serial = cert.serial_number + revoked = crl.get_revoked_certificate_by_serial_number(serial) + if revoked is not None: + revocation_date = revoked.revocation_date_utc + raise CollateralError( + f"{cert_name} certificate has been revoked. " + f"Serial: {serial:x}, Revocation date: {revocation_date}" + ) + + +# ============================================================================= +# Collateral Validation Orchestration +# ============================================================================= + +@dataclass +class CollateralValidationResult: + """ + Result of collateral validation. + + Attributes: + collateral: The fetched TDX collateral + tcb_level: The matching TCB level for the platform + pck_extensions: Extracted PCK certificate extensions + """ + collateral: TdxCollateral + tcb_level: TcbLevel + pck_extensions: PckExtensions + + +def validate_collateral( + quote: "QuoteV4", + pck_chain: "PCKCertificateChain", + min_tcb_evaluation_data_number: int = DEFAULT_MIN_TCB_EVALUATION_DATA_NUMBER, +) -> CollateralValidationResult: + """ + Validate all collateral for a TDX quote. + + This function orchestrates the complete collateral validation flow: + 1. Extract PCK extensions (FMSPC, TCB components) + 2. Fetch collateral from Intel PCS (TCB Info, QE Identity) + 3. Check collateral freshness + 4. Validate certificate revocation + 5. Validate TCB status + 6. Validate TDX module identity + 7. Validate QE identity + + Args: + quote: Parsed TDX QuoteV4 + pck_chain: Extracted PCK certificate chain + min_tcb_evaluation_data_number: Minimum required tcbEvaluationDataNumber + + Returns: + CollateralValidationResult with validated collateral and TCB level + + Raises: + CollateralError: If any collateral validation step fails + """ + + + # Step 1: Extract PCK extensions + try: + pck_extensions = extract_pck_extensions(pck_chain.pck_cert) + except PckExtensionError as e: + raise CollateralError(f"Failed to extract PCK extensions: {e}") + + # Step 2: Fetch collateral from Intel PCS + collateral = fetch_collateral(pck_extensions, pck_chain.pck_cert) + + # Step 3: Check collateral freshness + check_collateral_freshness(collateral, min_tcb_evaluation_data_number) + + # Step 4: Validate certificate revocation + validate_certificate_revocation( + collateral, pck_chain.pck_cert, pck_chain.intermediate_cert + ) + + # Step 5: Validate TCB status + tcb_level = validate_tcb_status( + collateral.tcb_info.tcb_info, + quote.td_quote_body.tee_tcb_svn, + pck_extensions, + ) + + # Step 6: Validate TDX module identity + validate_tdx_module_identity( + collateral.tcb_info.tcb_info, + quote.td_quote_body.tee_tcb_svn, + quote.td_quote_body.mr_signer_seam, + quote.td_quote_body.seam_attributes, + ) + + # Step 7: Validate QE identity + qe_report = quote.signed_data.certification_data.qe_report_data + if qe_report is None: + raise CollateralError("Quote missing QE report certification data") + qe_parsed = qe_report.qe_report_parsed + miscselect_bytes = qe_parsed.misc_select.to_bytes(4, byteorder='little') + validate_qe_identity( + collateral.qe_identity.enclave_identity, + qe_parsed.isv_svn, + qe_parsed.mr_signer, + miscselect_bytes, + qe_parsed.attributes, + qe_parsed.isv_prod_id, + ) + + return CollateralValidationResult( + collateral=collateral, + tcb_level=tcb_level, + pck_extensions=pck_extensions, + ) diff --git a/src/tinfoil/attestation/intel_root_ca.py b/src/tinfoil/attestation/intel_root_ca.py new file mode 100644 index 0000000..bd31a36 --- /dev/null +++ b/src/tinfoil/attestation/intel_root_ca.py @@ -0,0 +1,80 @@ +""" +Intel SGX Root CA certificate for TDX attestation verification. + +This module provides the embedded Intel SGX Provisioning Certification Root CA +certificate, which is the trust anchor for verifying TDX attestation quotes. + +Certificate chain for TDX: + Intel SGX Root CA (this file - trust anchor) + └─► Intel SGX PCK Platform CA (intermediate, from collateral) + └─► Intel SGX PCK Certificate (leaf, embedded in quote) + +Source: https://certificates.trustedservices.intel.com/Intel_SGX_Provisioning_Certification_RootCA.pem +""" + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization + +# Intel SGX Provisioning Certification Root CA +# Subject: CN=Intel SGX Root CA, O=Intel Corporation, L=Santa Clara, ST=CA, C=US +# Valid: 2018-05-21 to 2049-12-31 +# Key: ECDSA P-256 +INTEL_SGX_ROOT_CA_PEM = b"""-----BEGIN CERTIFICATE----- +MIICjzCCAjSgAwIBAgIUImUM1lqdNInzg7SVUr9QGzknBqwwCgYIKoZIzj0EAwIw +aDEaMBgGA1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENv +cnBvcmF0aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJ +BgNVBAYTAlVTMB4XDTE4MDUyMTEwNDUxMFoXDTQ5MTIzMTIzNTk1OVowaDEaMBgG +A1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENvcnBvcmF0 +aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJBgNVBAYT +AlVTMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEC6nEwMDIYZOj/iPWsCzaEKi7 +1OiOSLRFhWGjbnBVJfVnkY4u3IjkDYYL0MxO4mqsyYjlBalTVYxFP2sJBK5zlKOB +uzCBuDAfBgNVHSMEGDAWgBQiZQzWWp00ifODtJVSv1AbOScGrDBSBgNVHR8ESzBJ +MEegRaBDhkFodHRwczovL2NlcnRpZmljYXRlcy50cnVzdGVkc2VydmljZXMuaW50 +ZWwuY29tL0ludGVsU0dYUm9vdENBLmRlcjAdBgNVHQ4EFgQUImUM1lqdNInzg7SV +Ur9QGzknBqwwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQEwCgYI +KoZIzj0EAwIDSQAwRgIhAOW/5QkR+S9CiSDcNoowLuPRLsWGf/Yi7GSX94BgwTwg +AiEA4J0lrHoMs+Xo5o/sX6O9QWxHRAvZUGOdRQ7cvqRXaqI= +-----END CERTIFICATE----- +""" + + +def get_intel_root_ca() -> x509.Certificate: + """ + Load and return the Intel SGX Root CA certificate. + + Returns: + Parsed X.509 certificate object + + Example: + >>> root_ca = get_intel_root_ca() + >>> print(root_ca.subject) + + """ + return x509.load_pem_x509_certificate(INTEL_SGX_ROOT_CA_PEM) + + +def get_intel_root_ca_public_key_der() -> bytes: + """ + Get the Intel SGX Root CA public key in DER format. + + This is useful for comparing against certificate chain anchors. + + Returns: + DER-encoded public key bytes + """ + cert = get_intel_root_ca() + return cert.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + +def get_intel_root_ca_der() -> bytes: + """ + Get the Intel SGX Root CA certificate in DER format. + + Returns: + DER-encoded certificate bytes + """ + cert = get_intel_root_ca() + return cert.public_bytes(serialization.Encoding.DER) diff --git a/src/tinfoil/attestation/pck_extensions.py b/src/tinfoil/attestation/pck_extensions.py new file mode 100644 index 0000000..749fd5d --- /dev/null +++ b/src/tinfoil/attestation/pck_extensions.py @@ -0,0 +1,293 @@ +""" +PCK Certificate Extension Parsing for Intel TDX attestation. + +This module extracts Intel SGX-specific extensions from PCK (Provisioning +Certification Key) certificates. These extensions contain critical information +for TCB (Trusted Computing Base) validation: + +- FMSPC: Firmware/Microcode Security Patch Cluster (6 bytes) +- PCEID: PCE ID (2 bytes) +- TCB: TCB components including PCE SVN and CPU SVN + +Intel SGX Extension OID hierarchy: + 1.2.840.113741.1.13.1 - SGX Extension (parent) + 1.2.840.113741.1.13.1.1 - PPID + 1.2.840.113741.1.13.1.2 - TCB + 1.2.840.113741.1.13.1.2.1-16 - TCB Components + 1.2.840.113741.1.13.1.2.17 - PCE SVN + 1.2.840.113741.1.13.1.2.18 - CPU SVN + 1.2.840.113741.1.13.1.3 - PCEID + 1.2.840.113741.1.13.1.4 - FMSPC +""" + +from contextlib import contextmanager +from dataclasses import dataclass + +from cryptography import x509 +from pyasn1.codec.der import decoder as der_decoder + +# Intel SGX OID base +_INTEL_SGX_OID_BASE = "1.2.840.113741.1.13.1" + +# Intel SGX Extension OIDs +OID_SGX_EXTENSION = x509.ObjectIdentifier(_INTEL_SGX_OID_BASE) +OID_PPID = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.1") +OID_TCB = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.2") +OID_PCEID = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.3") +OID_FMSPC = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.4") + +# TCB sub-OIDs +OID_PCE_SVN = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.2.17") +OID_CPU_SVN = x509.ObjectIdentifier(f"{_INTEL_SGX_OID_BASE}.2.18") + +# TCB component OID → index mapping (1-indexed OID to 0-indexed array position) +_TCB_COMPONENT_OID_INDEX = { + f"{_INTEL_SGX_OID_BASE}.2.{i}": i - 1 for i in range(1, 17) +} + +# Size constants (match go-tdx-guest/pcs/pcs.go) +PCK_CERT_EXTENSION_COUNT = 6 +SGX_EXTENSION_MIN_SIZE = 4 +TCB_EXTENSION_SIZE = 18 # 16 components + PCE SVN + CPU SVN +PPID_SIZE = 16 +CPU_SVN_SIZE = 16 +FMSPC_SIZE = 6 +PCEID_SIZE = 2 +TCB_COMPONENTS_COUNT = 16 + + +class PckExtensionError(Exception): + """Raised when PCK certificate extension parsing fails.""" + pass + + +@contextmanager +def _asn1_errors(label: str): + """Wrap unexpected ASN.1 exceptions as PckExtensionError.""" + try: + yield + except PckExtensionError: + raise + except Exception as e: + raise PckExtensionError(f"Unexpected ASN.1 structure in {label}: {e}") from e + + +@dataclass +class PckCertTCB: + """ + TCB (Trusted Computing Base) information from PCK certificate. + + Attributes: + pce_svn: PCE Security Version Number + cpu_svn: CPU SVN as raw bytes (16 bytes) + tcb_components: Individual TCB component SVNs (16 values) + """ + pce_svn: int + cpu_svn: bytes # 16 bytes + tcb_components: list[int] # 16 components + + def __str__(self) -> str: + return ( + f"PckCertTCB(pce_svn={self.pce_svn}, " + f"cpu_svn={self.cpu_svn.hex()}, " + f"components={self.tcb_components})" + ) + + +@dataclass +class PckExtensions: + """ + Intel SGX extensions extracted from a PCK certificate. + + Attributes: + ppid: Platform/Product ID as hex string (16 bytes) + tcb: TCB information + pceid: PCE ID as hex string (2 bytes) + fmspc: Firmware/Microcode FMSPC as hex string (6 bytes) + """ + ppid: str # 16 bytes hex + tcb: PckCertTCB + pceid: str # 2 bytes hex + fmspc: str # 6 bytes hex + + def __str__(self) -> str: + return ( + f"PckExtensions(fmspc={self.fmspc}, pceid={self.pceid}, " + f"ppid={self.ppid[:8]}...)" + ) + + +def extract_pck_extensions(cert: x509.Certificate) -> PckExtensions: + """ + Extract Intel SGX extensions from a PCK certificate. + + Args: + cert: PCK leaf certificate from the quote + + Returns: + PckExtensions containing FMSPC, PCEID, PPID, and TCB info + + Raises: + PckExtensionError: If required extensions are missing or malformed + """ + if len(cert.extensions) != PCK_CERT_EXTENSION_COUNT: + raise PckExtensionError( + f"PCK certificate has {len(cert.extensions)} extensions, " + f"expected {PCK_CERT_EXTENSION_COUNT}" + ) + + sgx_ext = None + for ext in cert.extensions: + if ext.oid == OID_SGX_EXTENSION: + sgx_ext = ext + break + + if sgx_ext is None: + raise PckExtensionError( + "PCK certificate does not contain Intel SGX extension " + f"(OID {OID_SGX_EXTENSION.dotted_string})" + ) + + return _parse_sgx_extension(sgx_ext.value.value) + + +def _der_decode(data: bytes, label: str = "value"): + """Decode DER data using pyasn1, wrapping errors as PckExtensionError. + + Checks for leftover bytes after decoding (matches Go's asn1.Unmarshal behavior). + """ + try: + result, remainder = der_decoder.decode(data) + except Exception as e: + raise PckExtensionError(f"Failed to decode ASN.1 {label}: {e}") from e + if remainder: + raise PckExtensionError( + f"Unexpected leftover bytes after decoding {label}: {len(remainder)} bytes" + ) + return result + + +def _octet_hex(component, name: str, expected_size: int) -> str: + """Extract bytes from a decoded ASN.1 value, validate size, return hex.""" + raw = bytes(component) + if len(raw) != expected_size: + raise PckExtensionError( + f"{name} has wrong size: expected {expected_size}, got {len(raw)}" + ) + return raw.hex() + + +def _parse_sgx_extension(raw_value: bytes) -> PckExtensions: + """Parse the SGX extension: a SEQUENCE OF SEQUENCE(OID, value).""" + outer_seq = _der_decode(raw_value, "SGX extension") + + if len(outer_seq) < SGX_EXTENSION_MIN_SIZE: + raise PckExtensionError( + f"SGX extension has {len(outer_seq)} elements, " + f"expected at least {SGX_EXTENSION_MIN_SIZE}" + ) + + ppid = None + tcb = None + pceid = None + fmspc = None + + with _asn1_errors("SGX extension"): + for item in outer_seq: + if len(item) < 2: + raise PckExtensionError( + f"Malformed SGX extension item: expected at least 2 fields, got {len(item)}" + ) + oid_str = str(item[0]) + + if oid_str == OID_PPID.dotted_string: + ppid = _octet_hex(item[1], "PPID", PPID_SIZE) + elif oid_str == OID_TCB.dotted_string: + tcb = _parse_tcb(item[1]) + elif oid_str == OID_PCEID.dotted_string: + pceid = _octet_hex(item[1], "PCEID", PCEID_SIZE) + elif oid_str == OID_FMSPC.dotted_string: + fmspc = _octet_hex(item[1], "FMSPC", FMSPC_SIZE) + + if fmspc is None: + raise PckExtensionError("FMSPC not found in PCK certificate") + if pceid is None: + raise PckExtensionError("PCEID not found in PCK certificate") + if ppid is None: + raise PckExtensionError("PPID not found in PCK certificate") + if tcb is None: + raise PckExtensionError("TCB not found in PCK certificate") + + return PckExtensions(ppid=ppid, tcb=tcb, pceid=pceid, fmspc=fmspc) + + +def _parse_tcb(value_component) -> PckCertTCB: + """ + Parse the TCB extension: a SEQUENCE of (OID, value) pairs for + 16 TCB components, PCE SVN, and CPU SVN. + + value_component is already a decoded pyasn1 Sequence from the parent decode. + """ + tcb_seq = value_component + + if len(tcb_seq) != TCB_EXTENSION_SIZE: + raise PckExtensionError( + f"TCB extension has {len(tcb_seq)} elements, " + f"expected {TCB_EXTENSION_SIZE}" + ) + + tcb_components = [0] * TCB_COMPONENTS_COUNT + pce_svn = 0 + cpu_svn = bytes(CPU_SVN_SIZE) + found_pce_svn = False + found_cpu_svn = False + found_components: set[int] = set() + + with _asn1_errors("TCB extension"): + for item in tcb_seq: + if len(item) < 2: + raise PckExtensionError( + f"Malformed TCB extension item: expected at least 2 fields, got {len(item)}" + ) + oid_str = str(item[0]) + + if oid_str == OID_PCE_SVN.dotted_string: + val = int(item[1]) + if val < 0 or val > 0xFFFF: + raise PckExtensionError( + f"PCE SVN value {val} out of uint16 range" + ) + pce_svn = val + found_pce_svn = True + elif oid_str == OID_CPU_SVN.dotted_string: + raw = bytes(item[1]) + if len(raw) != CPU_SVN_SIZE: + raise PckExtensionError( + f"CPU SVN has wrong size: expected {CPU_SVN_SIZE}, got {len(raw)}" + ) + cpu_svn = raw + found_cpu_svn = True + elif (idx := _TCB_COMPONENT_OID_INDEX.get(oid_str)) is not None: + val = int(item[1]) + if val < 0 or val > 0xFF: + raise PckExtensionError( + f"TCB component {idx + 1} value {val} out of byte range" + ) + tcb_components[idx] = val + found_components.add(idx) + else: + raise PckExtensionError( + f"Unrecognized OID in TCB extension: {oid_str}" + ) + + if not found_pce_svn: + raise PckExtensionError("PCE SVN not found in TCB extension") + if not found_cpu_svn: + raise PckExtensionError("CPU SVN not found in TCB extension") + if len(found_components) != TCB_COMPONENTS_COUNT: + missing = set(range(TCB_COMPONENTS_COUNT)) - found_components + raise PckExtensionError( + f"Missing TCB components: {sorted(i + 1 for i in missing)}" + ) + + return PckCertTCB(pce_svn=pce_svn, cpu_svn=cpu_svn, tcb_components=tcb_components) diff --git a/src/tinfoil/attestation/types.py b/src/tinfoil/attestation/types.py new file mode 100644 index 0000000..bce4704 --- /dev/null +++ b/src/tinfoil/attestation/types.py @@ -0,0 +1,168 @@ +""" +Shared types, errors, and protocol constants for attestation. + +This module is the canonical source for types used across TDX and SEV +attestation modules. It has no intra-package dependencies, so any module +can import from it without risk of circular imports. +""" + +import hashlib +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + + +# ============================================================================= +# Protocol-level constants (shared across TDX and SEV) +# ============================================================================= + +TLS_KEY_FP_SIZE = 32 # SHA-256 TLS public key fingerprint (bytes) +HPKE_KEY_SIZE = 32 # HPKE public key (bytes) + +# RTMR3 should always be zeros (48 bytes = 96 hex chars) +RTMR3_ZERO = "0" * 96 + +# Register layout constants per platform +TDX_REGISTER_COUNT = 5 # [mrtd, rtmr0, rtmr1, rtmr2, rtmr3] +TDX_MRTD_IDX = 0 +TDX_RTMR0_IDX = 1 +TDX_RTMR1_IDX = 2 +TDX_RTMR2_IDX = 3 +TDX_RTMR3_IDX = 4 +SEV_REGISTER_COUNT = 1 # [snp_measurement] +MULTIPLATFORM_REGISTER_COUNT = 3 # [snp_measurement, rtmr1, rtmr2] +MULTIPLATFORM_SNP_IDX = 0 +MULTIPLATFORM_RTMR1_IDX = 1 +MULTIPLATFORM_RTMR2_IDX = 2 + + +# ============================================================================= +# Predicate types +# ============================================================================= + +class PredicateType(str, Enum): + """Predicate types for attestation""" + SEV_GUEST_V2 = "https://tinfoil.sh/predicate/sev-snp-guest/v2" + TDX_GUEST_V2 = "https://tinfoil.sh/predicate/tdx-guest/v2" + SNP_TDX_MULTIPLATFORM_v1 = "https://tinfoil.sh/predicate/snp-tdx-multiplatform/v1" + HARDWARE_MEASUREMENTS_V1 = "https://tinfoil.sh/predicate/hardware-measurements/v1" + + +TDX_TYPES = (PredicateType.TDX_GUEST_V2,) + + +# ============================================================================= +# Errors +# ============================================================================= + +class AttestationError(Exception): + """Base class for attestation errors""" + pass + +class FormatMismatchError(AttestationError): + """Raised when attestation formats don't match""" + pass + +class MeasurementMismatchError(AttestationError): + """Raised when measurements don't match""" + pass + +class Rtmr3NotZeroError(AttestationError): + """Raised when RTMR3 is not zeros""" + pass + +class HardwareMeasurementError(AttestationError): + """Raised when hardware measurement verification fails""" + pass + + +# ============================================================================= +# Data types +# ============================================================================= + +@dataclass +class HardwareMeasurement: + """Represents hardware platform measurements (MRTD and RTMR0 for TDX)""" + id: str # platform@digest + mrtd: str + rtmr0: str + +@dataclass +class Measurement: + """Represents measurement data""" + type: PredicateType + registers: List[str] + + def fingerprint(self) -> str: + """ + Computes the SHA-256 hash of the predicate type and all measurement + registers. Always returns a 64-char hex digest regardless of the + number of registers, so callers get a uniform format. + """ + if not self.registers: + raise ValueError("Cannot compute fingerprint: no measurement registers") + + all_data = self.type.value + "".join(self.registers) + return hashlib.sha256(all_data.encode()).hexdigest() + + def assert_equal(self, other: 'Measurement') -> None: + """ + Checks if this measurement equals another measurement with multi-platform support + Raises appropriate error if they don't match + """ + # Direct comparison for same types + if self.type == other.type: + if len(self.registers) != len(other.registers) or self.registers != other.registers: + raise MeasurementMismatchError() + return + + # Multi-platform comparison support + if self.type == PredicateType.SNP_TDX_MULTIPLATFORM_v1: + if other.type == PredicateType.TDX_GUEST_V2 and len(other.registers) == TDX_REGISTER_COUNT: + if (len(self.registers) != MULTIPLATFORM_REGISTER_COUNT or + self.registers[MULTIPLATFORM_RTMR1_IDX] != other.registers[TDX_RTMR1_IDX] or + self.registers[MULTIPLATFORM_RTMR2_IDX] != other.registers[TDX_RTMR2_IDX]): + raise MeasurementMismatchError() + if other.registers[TDX_RTMR3_IDX] != RTMR3_ZERO: + raise Rtmr3NotZeroError(f"RTMR3 must be zeros, got {other.registers[TDX_RTMR3_IDX]}") + return + elif other.type == PredicateType.SEV_GUEST_V2 and len(other.registers) == SEV_REGISTER_COUNT: + if (len(self.registers) != MULTIPLATFORM_REGISTER_COUNT or + self.registers[MULTIPLATFORM_SNP_IDX] != other.registers[0]): + raise MeasurementMismatchError() + return + + # Reverse comparisons + if other.type == PredicateType.SNP_TDX_MULTIPLATFORM_v1: + try: + other.assert_equal(self) + return + except (FormatMismatchError, MeasurementMismatchError): + raise + + # If we get here, the formats are incompatible + raise FormatMismatchError() + + def __str__(self) -> str: + """Returns a human-readable string representation of the measurement""" + if self.type == PredicateType.SEV_GUEST_V2 and len(self.registers) == SEV_REGISTER_COUNT: + return f"Measurement(type={self.type.value}, snp_measurement={self.registers[0][:16]}...)" + + elif self.type == PredicateType.TDX_GUEST_V2 and len(self.registers) == TDX_REGISTER_COUNT: + labels = ["mrtd", "rtmr0", "rtmr1", "rtmr2", "rtmr3"] + parts = [f"{label}={reg[:16]}..." for label, reg in zip(labels, self.registers)] + return f"Measurement(type={self.type.value}, {', '.join(parts)})" + + elif self.type == PredicateType.SNP_TDX_MULTIPLATFORM_v1 and len(self.registers) == MULTIPLATFORM_REGISTER_COUNT: + labels = ["snp_measurement", "rtmr1", "rtmr2"] + parts = [f"{label}={reg[:16]}..." for label, reg in zip(labels, self.registers)] + return f"Measurement(type={self.type.value}, {', '.join(parts)})" + + return f"Measurement(type={self.type.value}, registers={len(self.registers)} items)" + +@dataclass +class Verification: + """Represents verification results""" + measurement: Measurement + public_key_fp: str + hpke_public_key: Optional[str] = None diff --git a/src/tinfoil/attestation/utils.py b/src/tinfoil/attestation/utils.py new file mode 100644 index 0000000..ccc4ef3 --- /dev/null +++ b/src/tinfoil/attestation/utils.py @@ -0,0 +1,36 @@ +""" +Shared utility functions for attestation modules. + +This module has no intra-package dependencies, so any module +can import from it without risk of circular imports. +""" + +import gzip +import io + +MAX_DECOMPRESSED_SIZE = 10 * 1024 * 1024 # 10 MiB + + +def safe_gzip_decompress(data: bytes, max_size: int = MAX_DECOMPRESSED_SIZE) -> bytes: + """Decompress gzip data with a size limit to prevent gzip bombs. + + Args: + data: Gzip-compressed bytes + max_size: Maximum allowed decompressed size + + Returns: + Decompressed bytes + + Raises: + ValueError: If decompressed data exceeds max_size or decompression fails + """ + try: + with gzip.GzipFile(fileobj=io.BytesIO(data)) as f: + result = f.read(max_size + 1) + except (OSError, EOFError) as e: + raise ValueError(f"Gzip decompression failed: {e}") from e + if len(result) > max_size: + raise ValueError( + f"Decompressed attestation exceeds maximum size ({max_size} bytes)" + ) + return result diff --git a/src/tinfoil/attestation/validate.py b/src/tinfoil/attestation/validate_sev.py similarity index 99% rename from src/tinfoil/attestation/validate.py rename to src/tinfoil/attestation/validate_sev.py index d1a9997..6e911cd 100644 --- a/src/tinfoil/attestation/validate.py +++ b/src/tinfoil/attestation/validate_sev.py @@ -1,16 +1,16 @@ from dataclasses import dataclass, field from typing import Optional, List, Dict -from .abi_sevsnp import ( +from .abi_sev import ( Report, TCBParts, SnpPolicy, SnpPlatformInfo, ReportSigner, ) -from .verify import CertificateChain +from .verify_sev import CertificateChain -@dataclass +@dataclass(frozen=True) class ValidationOptions: """ Verification options for an SEV-SNP attestation report. diff --git a/src/tinfoil/attestation/validate_tdx.py b/src/tinfoil/attestation/validate_tdx.py new file mode 100644 index 0000000..442eb26 --- /dev/null +++ b/src/tinfoil/attestation/validate_tdx.py @@ -0,0 +1,428 @@ +""" +TDX Policy Validation. + +This module provides pure policy validation functions for TDX quote fields. + +This module contains: +- Policy validation functions (validate_xfam, validate_td_attributes, etc.) +- Policy option dataclasses (PolicyOptions, HeaderOptions, TdQuoteBodyOptions) +- Policy-related constants (XFAM_FIXED*, TD_ATTRIBUTES_FIXED*, etc.) +""" + +import struct +from dataclasses import dataclass, field +from typing import Optional + +from .abi_tdx import ( + QuoteV4, + MR_SEAM_SIZE, + TD_ATTRIBUTES_SIZE, + XFAM_SIZE, + MR_TD_SIZE, + MR_CONFIG_ID_SIZE, + MR_OWNER_SIZE, + MR_OWNER_CONFIG_SIZE, + RTMR_SIZE, + RTMR_COUNT, + REPORT_DATA_SIZE, + QE_VENDOR_ID_SIZE, + SEAM_ATTRIBUTES_SIZE, + MR_SIGNER_SEAM_SIZE, + TEE_TCB_SVN_SIZE, +) + + +# ============================================================================= +# Policy Validation Error +# ============================================================================= + +class TdxValidationError(Exception): + """Raised when TDX policy validation fails.""" + pass + + +# ============================================================================= +# Policy Validation Constants +# ============================================================================= + +# XFAM fixed bit constraints +# If bit X is 1 in XFAM_FIXED1, it must be 1 in any XFAM +XFAM_FIXED1 = 0x00000003 +# If bit X is 0 in XFAM_FIXED0, it must be 0 in any XFAM +XFAM_FIXED0 = 0x0006DBE7 + +# TD_ATTRIBUTES fixed bit constraints +# No FIXED1 bits for TD_ATTRIBUTES currently (no bits are mandatory-set). + +# TD_ATTRIBUTES bit definitions +TD_ATTRIBUTES_DEBUG_BIT = 0x1 # Bit 0: DEBUG mode +TD_ATTRIBUTES_SEPT_VE_DIS = 1 << 28 # Bit 28: Disable EPT violation #VE +TD_ATTRIBUTES_PKS = 1 << 30 # Bit 30: Supervisor Protection Keys +TD_ATTRIBUTES_PERFMON = 1 << 63 # Bit 63: Performance monitoring + +# If bit X is 0 in TD_ATTRIBUTES_FIXED0, it must be 0 in any TD_ATTRIBUTES +# Supported bits: DEBUG, SEPT_VE_DIS, PKS, PERFMON +TD_ATTRIBUTES_FIXED0 = ( + TD_ATTRIBUTES_DEBUG_BIT | + TD_ATTRIBUTES_SEPT_VE_DIS | + TD_ATTRIBUTES_PKS | + TD_ATTRIBUTES_PERFMON +) + + + +# ============================================================================= +# Policy Validation Options +# ============================================================================= + +@dataclass(frozen=True) +class HeaderOptions: + """ + Validation options for TDX quote header fields. + Mirrors Go's validate.HeaderOptions struct. + + All fields are optional - set to None to skip the check. + """ + # Expected QE_VENDOR_ID (16 bytes, not checked if None) + qe_vendor_id: Optional[bytes] = None + + +@dataclass(frozen=True) +class TdQuoteBodyOptions: + """ + Validation options for TDX quote body fields. + Mirrors Go's validate.TdQuoteBodyOptions struct. + + All fields are optional - set to None to skip the check. + """ + # Minimum TEE TCB SVN (16 bytes, component-wise comparison) + minimum_tee_tcb_svn: Optional[bytes] = None + # Expected MR_SEAM (48 bytes) + mr_seam: Optional[bytes] = None + # Expected TD_ATTRIBUTES (8 bytes) + td_attributes: Optional[bytes] = None + # Expected XFAM (8 bytes) + xfam: Optional[bytes] = None + # Expected MR_TD (48 bytes) + mr_td: Optional[bytes] = None + # Expected MR_CONFIG_ID (48 bytes) + mr_config_id: Optional[bytes] = None + # Expected MR_OWNER (48 bytes) + mr_owner: Optional[bytes] = None + # Expected MR_OWNER_CONFIG (48 bytes) + mr_owner_config: Optional[bytes] = None + # Expected RTMRs (list of 4 x 48 bytes) + rtmrs: Optional[tuple[Optional[bytes], ...]] = None + # Expected REPORT_DATA (64 bytes) + report_data: Optional[bytes] = None + # Any permitted MR_TD values (list of 48-byte values) + any_mr_td: Optional[tuple[Optional[bytes], ...]] = None + # Any permitted MR_SEAM values (list of 48-byte values) + any_mr_seam: Optional[tuple[Optional[bytes], ...]] = None + + +@dataclass(frozen=True) +class PolicyOptions: + """ + Complete validation options for TDX quote policy validation. + Mirrors Go's validate.Options struct. + """ + header: HeaderOptions = field(default_factory=HeaderOptions) + td_quote_body: TdQuoteBodyOptions = field(default_factory=TdQuoteBodyOptions) + + +# ============================================================================= +# Policy Validation Helper Functions +# ============================================================================= + +def _check_option_length(name: str, expected: int, value: Optional[bytes]) -> None: + """Check field length if value is provided.""" + if value is not None and len(value) != expected: + raise TdxValidationError( + f"Option '{name}' length is {len(value)}, expected {expected}" + ) + + +def _check_options_lengths(options: PolicyOptions) -> None: + """Validate all option field lengths.""" + h = options.header + t = options.td_quote_body + + _check_option_length("qe_vendor_id", QE_VENDOR_ID_SIZE, h.qe_vendor_id) + _check_option_length("minimum_tee_tcb_svn", TEE_TCB_SVN_SIZE, t.minimum_tee_tcb_svn) + _check_option_length("mr_seam", MR_SEAM_SIZE, t.mr_seam) + _check_option_length("td_attributes", TD_ATTRIBUTES_SIZE, t.td_attributes) + _check_option_length("xfam", XFAM_SIZE, t.xfam) + _check_option_length("mr_td", MR_TD_SIZE, t.mr_td) + _check_option_length("mr_config_id", MR_CONFIG_ID_SIZE, t.mr_config_id) + _check_option_length("mr_owner", MR_OWNER_SIZE, t.mr_owner) + _check_option_length("mr_owner_config", MR_OWNER_CONFIG_SIZE, t.mr_owner_config) + _check_option_length("report_data", REPORT_DATA_SIZE, t.report_data) + + if t.rtmrs is not None: + if len(t.rtmrs) != RTMR_COUNT: + raise TdxValidationError( + f"Option 'rtmrs' has {len(t.rtmrs)} entries, expected {RTMR_COUNT}" + ) + for i, rtmr in enumerate(t.rtmrs): + if rtmr is not None and len(rtmr) != RTMR_SIZE: + raise TdxValidationError( + f"Option 'rtmrs[{i}]' length is {len(rtmr)}, expected {RTMR_SIZE}" + ) + + if t.any_mr_td is not None: + if not any(v is not None for v in t.any_mr_td): + raise TdxValidationError("Option 'any_mr_td' contains no non-None entries") + for i, mr_td in enumerate(t.any_mr_td): + if mr_td is not None and len(mr_td) != MR_TD_SIZE: + raise TdxValidationError( + f"Option 'any_mr_td[{i}]' length is {len(mr_td)}, expected {MR_TD_SIZE}" + ) + + if t.any_mr_seam is not None: + if not any(v is not None for v in t.any_mr_seam): + raise TdxValidationError("Option 'any_mr_seam' contains no non-None entries") + for i, mr_seam in enumerate(t.any_mr_seam): + if mr_seam is not None and len(mr_seam) != MR_SEAM_SIZE: + raise TdxValidationError( + f"Option 'any_mr_seam[{i}]' length is {len(mr_seam)}, expected {MR_SEAM_SIZE}" + ) + + # Mutual exclusion: exact-match and allowlist cannot both be set + if t.mr_td is not None and t.any_mr_td is not None and len(t.any_mr_td) > 0: + raise TdxValidationError( + "Cannot set both 'mr_td' and 'any_mr_td' - use one or the other" + ) + if t.mr_seam is not None and t.any_mr_seam is not None and len(t.any_mr_seam) > 0: + raise TdxValidationError( + "Cannot set both 'mr_seam' and 'any_mr_seam' - use one or the other" + ) + + +def _byte_check( + field_name: str, + given: bytes, + expected: Optional[bytes], +) -> None: + """Check exact byte match if expected is provided.""" + if expected is None: + return # Skip check + + if given != expected: + raise TdxValidationError( + f"Quote field {field_name} is {given.hex()}, expected {expected.hex()}" + ) + + +def _is_svn_higher_or_equal(quote_svn: bytes, min_svn: Optional[bytes]) -> bool: + """Component-wise SVN comparison. Returns True if min_svn is None.""" + if min_svn is None: + return True + if len(quote_svn) != len(min_svn): + return False + for q, m in zip(quote_svn, min_svn): + if q < m: + return False + return True + + +# ============================================================================= +# Policy Validation Core Functions +# ============================================================================= + +def validate_xfam(xfam: bytes) -> None: + """ + Validate XFAM fixed bit constraints. + + Args: + xfam: 8-byte XFAM from quote + + Raises: + TdxValidationError: If fixed bit constraints violated + """ + if len(xfam) != XFAM_SIZE: + raise TdxValidationError(f"XFAM size is {len(xfam)}, expected {XFAM_SIZE}") + + value = struct.unpack(' None: + """ + Validate TD_ATTRIBUTES fixed bit constraints. + + Per Intel DCAP spec Section 2.3.2: "Verify that all TD Under Debug + flags (i.e., the TDATTRIBUTES.TUD field in the TD Quote Body) are + set to zero." + + Args: + td_attributes: 8-byte TD_ATTRIBUTES from quote + + Raises: + TdxValidationError: If validation fails + """ + if len(td_attributes) != TD_ATTRIBUTES_SIZE: + raise TdxValidationError( + f"TD_ATTRIBUTES size is {len(td_attributes)}, expected {TD_ATTRIBUTES_SIZE}" + ) + + value = struct.unpack(' None: + """Validate that a field is all zeros with the expected size.""" + if len(value) != expected_size: + raise TdxValidationError( + f"{name} size is {len(value)}, expected {expected_size}" + ) + if value != b'\x00' * expected_size: + raise TdxValidationError( + f"{name} must be zero for {context}, got {value.hex()}" + ) + + +def validate_seam_attributes(seam_attributes: bytes) -> None: + """ + Validate SEAMATTRIBUTES is zero (required for TDX 1.0/1.5). + + Per Intel TDX DCAP Quoting Library API section 2.3.2. + + Args: + seam_attributes: 8-byte SEAMATTRIBUTES from quote + + Raises: + TdxValidationError: If not zero + """ + _validate_zero_field("SEAMATTRIBUTES", seam_attributes, SEAM_ATTRIBUTES_SIZE, "TDX 1.0/1.5") + + +def validate_mr_signer_seam(mr_signer_seam: bytes) -> None: + """ + Validate MRSIGNERSEAM is zero (required for Intel TDX Module). + + Per Intel TDX DCAP Quoting Library API section 2.3.2. + + Args: + mr_signer_seam: 48-byte MRSIGNERSEAM from quote + + Raises: + TdxValidationError: If not zero + """ + _validate_zero_field("MRSIGNERSEAM", mr_signer_seam, MR_SIGNER_SEAM_SIZE, "Intel TDX Module") + + +def _validate_exact_byte_matches(quote: QuoteV4, options: PolicyOptions) -> None: + """Validate exact byte matches for configured fields.""" + t = options.td_quote_body + h = options.header + body = quote.td_quote_body + + _byte_check("MR_SEAM", body.mr_seam, t.mr_seam) + _byte_check("TD_ATTRIBUTES", body.td_attributes, t.td_attributes) + _byte_check("XFAM", body.xfam, t.xfam) + _byte_check("MR_TD", body.mr_td, t.mr_td) + _byte_check("MR_CONFIG_ID", body.mr_config_id, t.mr_config_id) + _byte_check("MR_OWNER", body.mr_owner, t.mr_owner) + _byte_check("MR_OWNER_CONFIG", body.mr_owner_config, t.mr_owner_config) + _byte_check("REPORT_DATA", body.report_data, t.report_data) + _byte_check("QE_VENDOR_ID", quote.header.qe_vendor_id, h.qe_vendor_id) + + # RTMR checks + if t.rtmrs is not None: + if len(body.rtmrs) != len(t.rtmrs): + raise TdxValidationError( + f"RTMR count mismatch: quote has {len(body.rtmrs)}, " + f"policy expects {len(t.rtmrs)}" + ) + for i, (given, expected) in enumerate(zip(body.rtmrs, t.rtmrs)): + if expected is not None: + _byte_check(f"RTMR[{i}]", given, expected) + + # Any MR_TD check (at least one must match) + if t.any_mr_td is not None and len(t.any_mr_td) > 0: + mr_td = body.mr_td + if not any(mr_td == allowed for allowed in t.any_mr_td if allowed is not None): + raise TdxValidationError( + f"MR_TD {mr_td.hex()} does not match any allowed value" + ) + + # Any MR_SEAM check (at least one must match) + if t.any_mr_seam is not None and len(t.any_mr_seam) > 0: + mr_seam = body.mr_seam + if not any(mr_seam == allowed for allowed in t.any_mr_seam if allowed is not None): + raise TdxValidationError( + f"MR_SEAM {mr_seam.hex()} does not match any allowed value" + ) + + +def _validate_min_versions(quote: QuoteV4, options: PolicyOptions) -> None: + """Validate minimum version requirements.""" + t = options.td_quote_body + + # TEE TCB SVN check + if t.minimum_tee_tcb_svn is not None: + if not _is_svn_higher_or_equal(quote.td_quote_body.tee_tcb_svn, t.minimum_tee_tcb_svn): + raise TdxValidationError( + f"TEE_TCB_SVN {quote.td_quote_body.tee_tcb_svn.hex()} is less than " + f"minimum {t.minimum_tee_tcb_svn.hex()}" + ) + + + +def validate_tdx_policy(quote: QuoteV4, options: PolicyOptions) -> None: + """ + Validate a TDX QuoteV4 against policy options. + + This is the main entry point for TDX quote policy validation. + It performs policy-based validation only - no cryptographic verification + or collateral fetching. + + Args: + quote: Parsed TDX QuoteV4 + options: Validation options + + Raises: + TdxValidationError: If any validation check fails + """ + # Validate option field lengths + _check_options_lengths(options) + + # Fixed bit validations (always run) + validate_xfam(quote.td_quote_body.xfam) + validate_td_attributes(quote.td_quote_body.td_attributes) + + # SEAM validations (required for TDX 1.0/1.5) + validate_seam_attributes(quote.td_quote_body.seam_attributes) + validate_mr_signer_seam(quote.td_quote_body.mr_signer_seam) + + # Exact byte match validations (optional based on options) + _validate_exact_byte_matches(quote, options) + + # Minimum version validations (optional based on options) + _validate_min_versions(quote, options) diff --git a/src/tinfoil/attestation/verify.py b/src/tinfoil/attestation/verify_sev.py similarity index 89% rename from src/tinfoil/attestation/verify.py rename to src/tinfoil/attestation/verify_sev.py index 96fdf01..29319e8 100644 --- a/src/tinfoil/attestation/verify.py +++ b/src/tinfoil/attestation/verify_sev.py @@ -11,22 +11,32 @@ from OpenSSL import crypto import platformdirs -from .abi_sevsnp import (Report, ReportSigner, TCBParts) +from .abi_sev import (Report, ReportSigner, TCBParts) from .genoa_cert_chain import (ARK_CERT, ASK_CERT) from cryptography import x509 from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec, utils from cryptography.x509.oid import ObjectIdentifier +from pyasn1.codec.der import decoder as der_decoder +from pyasn1.type.univ import Integer as Asn1Integer +from pyasn1.type.char import IA5String as Asn1IA5String import warnings from cryptography.utils import CryptographyDeprecationWarning # Type alias for certificate extensions Extensions: TypeAlias = Dict[ObjectIdentifier, bytes] -# VCEK cache directory setup (can stay at module level) +# VCEK cache directory (created lazily on first use) _VCEK_CACHE_DIR = platformdirs.user_cache_dir("tinfoil", "tinfoil") -os.makedirs(_VCEK_CACHE_DIR, exist_ok=True) + + +def _ensure_vcek_cache_dir() -> None: + """Create the VCEK cache directory if it doesn't exist. Silently ignores failures.""" + try: + os.makedirs(_VCEK_CACHE_DIR, exist_ok=True) + except OSError: + pass class SnpOid: """OID extensions for the VCEK, used to verify attestation report""" @@ -107,8 +117,14 @@ def from_report(cls, report:Report) -> 'CertificateChain': response.raise_for_status() vcek_cert_data = response.content # Persist to cache so the next call is instant - with open(cache_path, "wb") as fh: - fh.write(vcek_cert_data) + _ensure_vcek_cache_dir() + try: + tmp_path = cache_path + ".tmp" + with open(tmp_path, "wb") as fh: + fh.write(vcek_cert_data) + os.replace(tmp_path, cache_path) + except OSError: + pass except requests.RequestException as e: raise ValueError(f"Failed to fetch VCEK certificate: {e}") from e @@ -260,8 +276,9 @@ def _validate_vcek_format(self): if (SnpOid.HWID not in extensions) or len(extensions[SnpOid.HWID]) != 64: # ChipIDSize raise ValueError(f"missing HWID extension for VCEK certificate") - if extensions[SnpOid.PRODUCT_NAME_1] != b'\x16\x05Genoa': - raise ValueError(f"unexpected PRODUCT_NAME_1 in VCEK certificate: {extensions[SnpOid.PRODUCT_NAME_1]}") + product_name = _decode_der_ia5string(extensions[SnpOid.PRODUCT_NAME_1]) + if product_name != "Genoa": + raise ValueError(f"unexpected PRODUCT_NAME_1 in VCEK certificate: {product_name}") def validate_vcek_tcb(self, tcb: TCBParts): """Validate the TCB extension in the VCEK certificate matches a given TCB""" @@ -310,16 +327,28 @@ def _get_certificate_extensions(cert: x509.Certificate) -> Extensions: return extensions def _decode_der_integer(der_bytes: bytes) -> int: - """Decode a DER-encoded INTEGER""" - if len(der_bytes) < 2 or der_bytes[0] != 0x02: - raise ValueError(f"Invalid DER INTEGER: {der_bytes.hex()}") - - length = der_bytes[1] - if len(der_bytes) != 2 + length: - raise ValueError(f"Invalid DER INTEGER length: {der_bytes.hex()}") - - # Convert the integer bytes to int (big-endian) - return int.from_bytes(der_bytes[2:2+length], byteorder='big') + """Decode a DER-encoded INTEGER, rejecting non-INTEGER types and trailing bytes.""" + try: + result, remainder = der_decoder.decode(der_bytes, asn1Spec=Asn1Integer()) + except Exception as e: + raise ValueError(f"Failed to decode DER integer: {e}") from e + if remainder: + raise ValueError( + f"Unexpected trailing bytes after DER integer: {len(remainder)} bytes" + ) + return int(result) + +def _decode_der_ia5string(der_bytes: bytes) -> str: + """Decode a DER-encoded IA5String, rejecting non-IA5String types and trailing bytes.""" + try: + result, remainder = der_decoder.decode(der_bytes, asn1Spec=Asn1IA5String()) + except Exception as e: + raise ValueError(f"Failed to decode DER IA5String: {e}") from e + if remainder: + raise ValueError( + f"Unexpected trailing bytes after DER IA5String: {len(remainder)} bytes" + ) + return str(result) def _validateAmdLocation(name: x509.Name) -> bool: """Validate that the certificate subject name matches AMD's expected values. @@ -370,10 +399,18 @@ def check_singleton_list(values: list[str], field_name: str, expected: str) -> b def _VCEKCertURL(productName: str, chip_id: bytes, reported_tcb: int) -> str: # TODO add support for other product names """Generate the VCEK certificate URL based on the product name, chip ID, and reported TCB""" + import urllib.parse parts = TCBParts.from_int(reported_tcb) base_url = "https://kds-proxy.tinfoil.sh/vcek/v1" chip_id_hex = binascii.hexlify(chip_id).decode('ascii') - return f"{base_url}/{productName}/{chip_id_hex}?blSPL={parts.bl_spl}&teeSPL={parts.tee_spl}&snpSPL={parts.snp_spl}&ucodeSPL={parts.ucode_spl}" + path = f"{base_url}/{urllib.parse.quote(productName, safe='')}/{chip_id_hex}" + query = urllib.parse.urlencode({ + "blSPL": parts.bl_spl, + "teeSPL": parts.tee_spl, + "snpSPL": parts.snp_spl, + "ucodeSPL": parts.ucode_spl, + }) + return f"{path}?{query}" def _verify_report_signature(vcek: x509.Certificate, report: Report) -> bool: """Verify the attestation report signature using VCEK's public key""" diff --git a/src/tinfoil/attestation/verify_tdx.py b/src/tinfoil/attestation/verify_tdx.py new file mode 100644 index 0000000..0c11891 --- /dev/null +++ b/src/tinfoil/attestation/verify_tdx.py @@ -0,0 +1,330 @@ +""" +TDX Quote cryptographic verification. + +This module implements the cryptographic verification of TDX attestation quotes, +including signature verification and certificate chain validation. + +Verification flow: +1. Extract PCK certificate chain from quote +2. Verify PCK chain against Intel SGX Root CA +3. Verify quote signature (ECDSA P-256 over Header || TdQuoteBody) +4. Verify QE report signature using PCK leaf certificate +5. Verify QE report data binding (attestation key hash) +""" + +import hashlib +from dataclasses import dataclass + + +from cryptography import x509 +from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature +from cryptography.hazmat.primitives.asymmetric.utils import Prehashed + +from .abi_tdx import ( + QuoteV4, + QUOTE_HEADER_START, + QUOTE_BODY_END, + SIGNATURE_SIZE, + ATTESTATION_KEY_SIZE, + ECDSA_P256_COMPONENT_SIZE, + SHA256_HASH_SIZE, + PCK_CERT_CHAIN_COUNT, +) +from .cert_utils import parse_pem_chain, verify_intel_chain, CertificateChainError + + +class TdxVerificationError(Exception): + """Raised when TDX quote verification fails.""" + pass + + +@dataclass +class PCKCertificateChain: + """ + PCK Certificate Chain extracted from the quote. + + Contains three certificates: + - PCK leaf certificate (signs the QE report) + - Intermediate CA certificate + - Root CA certificate (should match Intel SGX Root CA) + """ + pck_cert: x509.Certificate # PCK Leaf certificate + intermediate_cert: x509.Certificate # Intermediate CA certificate + root_cert: x509.Certificate # Root CA certificate + + +def extract_pck_cert_chain(quote: QuoteV4) -> PCKCertificateChain: + """ + Extract the PCK certificate chain from the quote. + + The certificate chain is embedded in the quote's certification data + as concatenated PEM certificates: PCK Leaf || Intermediate CA || Root CA. + + Args: + quote: Parsed TDX QuoteV4 + + Returns: + PCKCertificateChain with three certificates + + Raises: + TdxVerificationError: If certificate chain is missing or malformed + """ + pck_chain_data = quote.signed_data.certification_data.get_pck_chain() + cert_pem = pck_chain_data.cert_data + if not cert_pem: + raise TdxVerificationError("PCK certificate chain is empty") + + # Parse concatenated PEM certificates + try: + certs = parse_pem_chain(cert_pem) + except CertificateChainError as e: + raise TdxVerificationError(f"Failed to parse PCK certificate chain: {e}") + + if len(certs) != PCK_CERT_CHAIN_COUNT: + raise TdxVerificationError( + f"PCK certificate chain should contain {PCK_CERT_CHAIN_COUNT} certificates, got {len(certs)}" + ) + + return PCKCertificateChain( + pck_cert=certs[0], + intermediate_cert=certs[1], + root_cert=certs[2], + ) + + +def verify_pck_chain(chain: PCKCertificateChain) -> None: + """ + Verify the PCK certificate chain against Intel SGX Root CA. + + Verification steps: + 1. Verify root cert matches embedded Intel SGX Root CA + 2. Verify certificate chain (validity + signatures) using cryptography library + + Args: + chain: PCK certificate chain from quote + + Raises: + TdxVerificationError: If chain verification fails + """ + certs = [chain.pck_cert, chain.intermediate_cert, chain.root_cert] + try: + verify_intel_chain(certs, "PCK certificate chain") + except CertificateChainError as e: + raise TdxVerificationError(str(e)) + + +def verify_quote_signature(quote: QuoteV4, raw_quote: bytes) -> None: + """ + Verify the quote signature using the attestation key. + + The signature covers SHA256(Header || TdQuoteBody). + + Args: + quote: Parsed QuoteV4 + raw_quote: Original raw quote bytes + + Raises: + TdxVerificationError: If signature verification fails + """ + # Get attestation key (64 bytes = raw P-256 point) + attestation_key_bytes = quote.signed_data.attestation_key + public_key = _bytes_to_p256_pubkey(attestation_key_bytes) + + # Get signature (64 bytes = R || S) + signature_bytes = quote.signed_data.signature + signature_der = _signature_to_der(signature_bytes) + + # Message = Header || TdQuoteBody (bytes 0x00 to 0x278) + message = raw_quote[QUOTE_HEADER_START:QUOTE_BODY_END] + message_hash = hashlib.sha256(message).digest() + + # Verify ECDSA signature over pre-hashed message + # Use Prehashed to avoid double-hashing (we already computed SHA256) + try: + public_key.verify( + signature_der, + message_hash, + ec.ECDSA(Prehashed(hashes.SHA256())) + ) + except InvalidSignature: + raise TdxVerificationError( + "Quote signature verification failed: signature does not match" + ) + + +def verify_qe_report_signature(quote: QuoteV4, pck_cert: x509.Certificate) -> None: + """ + Verify the QE report signature using the PCK leaf certificate. + + Caller must ensure quote.signed_data.certification_data.qe_report_data + is not None before calling this function. + + Args: + quote: Parsed QuoteV4 + pck_cert: PCK leaf certificate from the chain + + Raises: + TdxVerificationError: If signature verification fails + """ + qe_report_data = quote.signed_data.certification_data.qe_report_data + + # Get QE report and signature + qe_report = qe_report_data.qe_report # 384 bytes + qe_signature = qe_report_data.qe_report_signature # 64 bytes + + # Convert signature to DER format + signature_der = _signature_to_der(qe_signature) + + # Verify using PCK certificate's public key + try: + pck_public_key = pck_cert.public_key() + pck_public_key.verify(signature_der, qe_report, ec.ECDSA(hashes.SHA256())) + except InvalidSignature: + raise TdxVerificationError( + "QE report signature verification failed using PCK certificate" + ) + except (TypeError, AttributeError, ValueError, UnsupportedAlgorithm) as e: + raise TdxVerificationError( + f"PCK certificate has unexpected key type: {e}" + ) + + +def verify_qe_report_data_binding(quote: QuoteV4) -> None: + """ + Verify that the QE report data binds to the attestation key. + + This is a CRITICAL security check. The QE report's report_data field + must contain SHA256(attestation_key || qe_auth_data) padded to 64 bytes. + Without this check, an attacker could substitute a different attestation key. + + Caller must ensure quote.signed_data.certification_data.qe_report_data + is not None before calling this function. + + Args: + quote: Parsed QuoteV4 + + Raises: + TdxVerificationError: If binding verification fails + """ + qe_report_data = quote.signed_data.certification_data.qe_report_data + + # Get components + attestation_key = quote.signed_data.attestation_key # 64 bytes + qe_auth_data = qe_report_data.qe_auth_data # Variable + qe_report_data_field = qe_report_data.qe_report_parsed.report_data # 64 bytes + + # Compute expected: SHA256(attestation_key || qe_auth_data) || zeros + data_to_hash = attestation_key + qe_auth_data + expected_hash = hashlib.sha256(data_to_hash).digest() + expected_report_data = expected_hash + b'\x00' * SHA256_HASH_SIZE + + # Both values are public attestation data; no timing side-channel concern. + if qe_report_data_field != expected_report_data: + raise TdxVerificationError( + "QE report data binding verification failed: " + "SHA256(attestation_key || auth_data) does not match QE report data. " + "The attestation key may have been tampered with." + ) + + +def _bytes_to_p256_pubkey(key_bytes: bytes) -> ec.EllipticCurvePublicKey: + """ + Convert raw 64-byte P-256 public key to cryptography public key object. + + The raw format is X || Y (32 bytes each). + + Args: + key_bytes: 64 bytes representing X || Y coordinates + + Returns: + EllipticCurvePublicKey object + + Raises: + TdxVerificationError: If key format is invalid + """ + if len(key_bytes) != ATTESTATION_KEY_SIZE: + raise TdxVerificationError( + f"Attestation key is {len(key_bytes)} bytes, expected {ATTESTATION_KEY_SIZE}" + ) + + # Convert to uncompressed point format (0x04 || X || Y) + uncompressed = b'\x04' + key_bytes + + try: + return ec.EllipticCurvePublicKey.from_encoded_point( + ec.SECP256R1(), uncompressed + ) + except (ValueError, TypeError) as e: + raise TdxVerificationError(f"Invalid attestation key: {e}") + + +def _signature_to_der(sig_bytes: bytes) -> bytes: + """ + Convert raw R||S signature to DER format. + + TDX signatures are 64 bytes: R (32 bytes) || S (32 bytes). + DER format is required by cryptography library. + + Args: + sig_bytes: 64-byte raw signature + + Returns: + DER-encoded signature + + Raises: + TdxVerificationError: If signature format is invalid + """ + if len(sig_bytes) != SIGNATURE_SIZE: + raise TdxVerificationError( + f"Signature is {len(sig_bytes)} bytes, expected {SIGNATURE_SIZE}" + ) + + r = int.from_bytes(sig_bytes[0:ECDSA_P256_COMPONENT_SIZE], byteorder='big') + s = int.from_bytes(sig_bytes[ECDSA_P256_COMPONENT_SIZE:SIGNATURE_SIZE], byteorder='big') + + return encode_dss_signature(r, s) + + +def verify_tdx_quote(quote: QuoteV4, raw_quote: bytes) -> PCKCertificateChain: + """ + Perform full cryptographic verification of a TDX quote. + + This is the main entry point for TDX verification. It performs all + security-critical checks in order: + + 1. Extract and verify PCK certificate chain + 2. Verify quote signature using attestation key + 3. Verify QE report signature using PCK certificate + 4. Verify QE report data binding (critical security check) + + Args: + quote: Parsed QuoteV4 structure + raw_quote: Original raw quote bytes (needed for signature verification) + + Returns: + PCKCertificateChain on success (for further use in TCB checks) + + Raises: + TdxVerificationError: If any verification step fails + """ + # Step 1: Extract and verify PCK certificate chain + chain = extract_pck_cert_chain(quote) + verify_pck_chain(chain) + + # Step 2: Verify quote signature + verify_quote_signature(quote, raw_quote) + + # Step 3 & 4 require QE report data + if quote.signed_data.certification_data.qe_report_data is None: + raise TdxVerificationError("QE report data is missing") + + # Step 3: Verify QE report signature + verify_qe_report_signature(quote, chain.pck_cert) + + # Step 4: Verify QE report data binding (CRITICAL) + verify_qe_report_data_binding(quote) + + return chain diff --git a/src/tinfoil/client.py b/src/tinfoil/client.py index 124baa1..8462b82 100644 --- a/src/tinfoil/client.py +++ b/src/tinfoil/client.py @@ -1,19 +1,21 @@ import http.client import json import ssl +import urllib.error import urllib.request import httpx import random from dataclasses import dataclass from typing import Dict, Optional -from urllib.parse import urlparse +from urllib.parse import urlparse, urlencode import cryptography.x509 from cryptography.hazmat.primitives.serialization import PublicFormat, Encoding import hashlib -from .attestation import fetch_attestation +from .attestation import fetch_attestation, TDX_TYPES +from .attestation.attestation_tdx import verify_tdx_hardware from .github import fetch_latest_digest, fetch_attestation_bundle -from .sigstore import verify_attestation +from .sigstore import verify_attestation, fetch_latest_hardware_measurements @dataclass @@ -114,14 +116,14 @@ def wrap_socket(*args, **kwargs) -> ssl.SSLSocket: sock = ssl.create_default_context().wrap_socket(*args, **kwargs) cert_binary = sock.getpeercert(binary_form=True) if not cert_binary: - raise Exception("No certificate found") + raise ValueError("No certificate found") cert = cryptography.x509.load_der_x509_certificate(cert_binary) pub_der = cert.public_key().public_bytes( Encoding.DER, PublicFormat.SubjectPublicKeyInfo ) pk_fp = hashlib.sha256(pub_der).hexdigest() if pk_fp != expected_fp: - raise Exception(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}") + raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}") return sock return wrap_socket @@ -152,8 +154,9 @@ def verify(self) -> GroundTruth: if expected_snp_measurement is None: raise ValueError("snp_measurement not found in provided measurement") - # Get the actual measurement from the attestation - actual_measurement = verification.measurement.registers[0] # SNP measurement is the first register + if not verification.measurement.registers: + raise ValueError("No measurement registers found in attestation") + actual_measurement = verification.measurement.registers[0] if actual_measurement != expected_snp_measurement: raise ValueError(f"SNP measurement mismatch: expected {expected_snp_measurement}, got {actual_measurement}") @@ -169,17 +172,20 @@ def verify(self) -> GroundTruth: # GitHub-based verification digest = fetch_latest_digest(self.repo) sigstore_bundle = fetch_attestation_bundle(self.repo, digest) - + code_measurements = verify_attestation( - sigstore_bundle, - digest, + sigstore_bundle, + digest, self.repo ) - - # Verify measurements match - for (i, code_measurement) in enumerate(code_measurements.registers): - if code_measurement != verification.measurement.registers[i]: - raise ValueError("Code measurements do not match") + + # For TDX, verify hardware measurements (MRTD and RTMR0) + if verification.measurement.type in TDX_TYPES: + hardware_measurements = fetch_latest_hardware_measurements() + verify_tdx_hardware(hardware_measurements, verification.measurement) + + # Verify measurements match (handles cross-platform comparison) + code_measurements.assert_equal(verification.measurement) # Build ground truth from the verified attestation self._ground_truth = GroundTruth( @@ -231,17 +237,26 @@ def get(self, url: str, headers: Dict[str, str] = {}) -> Response: ) return self.make_request(req) -def get_router_address() -> str: +def get_router_address(platform: Optional[str] = None) -> str: """ Fetches the list of available routers from the ATC API and returns a randomly selected address. + + Args: + platform: Optional platform filter (e.g. "snp", "tdx"). + If None, returns routers for any platform. """ + routers_url = "https://atc.tinfoil.sh/routers" + if platform: + routers_url += "?" + urlencode({"platform": platform}) - routers_url = "https://atc.tinfoil.sh/routers?platform=snp" + try: + with urllib.request.urlopen(routers_url, timeout=15) as response: + routers = json.loads(response.read().decode('utf-8')) + except (urllib.error.URLError, json.JSONDecodeError) as e: + raise ValueError(f"Failed to fetch router addresses: {e}") from e - with urllib.request.urlopen(routers_url) as response: - routers = json.loads(response.read().decode('utf-8')) - if len(routers) == 0: - raise ValueError("No routers found in the response") - - return random.choice(routers) + if not isinstance(routers, list) or len(routers) == 0: + raise ValueError("No routers found in the response") + + return random.choice(routers) diff --git a/src/tinfoil/github.py b/src/tinfoil/github.py index c46db23..f263528 100644 --- a/src/tinfoil/github.py +++ b/src/tinfoil/github.py @@ -8,8 +8,14 @@ # --- Cache Setup --- _GITHUB_CACHE_DIR = platformdirs.user_cache_dir("tinfoil", "tinfoil") -# Create a subdirectory specific to GitHub attestations within the main cache -os.makedirs(_GITHUB_CACHE_DIR, exist_ok=True) + + +def _ensure_github_cache_dir() -> None: + """Create the GitHub cache directory if it doesn't exist. Silently ignores failures.""" + try: + os.makedirs(_GITHUB_CACHE_DIR, exist_ok=True) + except OSError: + pass def _attestation_bundle_cache_path(repo: str, digest: str) -> str: """Generate a safe filepath for the attestation bundle cache.""" @@ -33,12 +39,14 @@ def fetch_latest_digest(repo: str) -> str: Exception: If there's any error fetching or parsing the data """ url = f"https://api-github-proxy.tinfoil.sh/repos/{repo}/releases/latest" - release_response = requests.get(url) + release_response = requests.get(url, timeout=15) release_response.raise_for_status() response_data = json.loads(release_response.content) + if not isinstance(response_data, dict) or "tag_name" not in response_data: + raise ValueError(f"Unexpected release API response for {repo}: missing 'tag_name'") tag_name = response_data["tag_name"] - body = response_data["body"] + body = response_data.get("body") or "" # Backwards compatibility for old EIF releases eif_regex = re.compile(r'EIF hash: ([a-fA-F0-9]{64})') @@ -54,7 +62,7 @@ def fetch_latest_digest(repo: str) -> str: # Fallback option: fetch digest from github special endpoint digest_url = f"https://github-proxy.tinfoil.sh/{repo}/releases/download/{tag_name}/tinfoil.hash" - response = requests.get(digest_url) + response = requests.get(digest_url, timeout=15) if response.status_code != 200: raise Exception(f"Failed to fetch attestation digest: {response.status_code} {response.reason}") return response.text.strip() @@ -115,15 +123,17 @@ def fetch_attestation_bundle(repo: str, digest: str) -> bytes: # Encode the string to bytes for file writing and return value bundle_bytes_to_write = bundle_json_string.encode('utf-8') - # Write to cache + # Atomic write to cache + _ensure_github_cache_dir() try: - with open(cache_path, 'wb') as f: + tmp_path = cache_path + ".tmp" + with open(tmp_path, 'wb') as f: f.write(bundle_bytes_to_write) - except OSError as e: - # Don't fail the whole operation if cache write fails, just warn - print(f"Warning: Failed to write cache file {cache_path}: {e}", file=sys.stderr) + os.replace(tmp_path, cache_path) + except OSError: + pass - return bundle_json_string + return bundle_bytes_to_write except (KeyError, IndexError, TypeError) as e: raise Exception(f"Invalid attestation response format from {url}: {e}. Response: {response_data}") from e diff --git a/src/tinfoil/sigstore.py b/src/tinfoil/sigstore.py index dfd7446..495c0b2 100644 --- a/src/tinfoil/sigstore.py +++ b/src/tinfoil/sigstore.py @@ -5,7 +5,9 @@ import json import re -from .attestation import Measurement, PredicateType +from typing import List +from .attestation import Measurement, PredicateType, HardwareMeasurement +from .github import fetch_latest_digest, fetch_attestation_bundle OIDC_ISSUER = "https://token.actions.githubusercontent.com" @@ -31,6 +33,54 @@ def verify(self, cert: Certificate) -> None: f"({_OIDC_GITHUB_WORKFLOW_REF_OID.dotted_string}) extension" ) + +def _verify_dsse_bundle(bundle_json: bytes, digest: str, repo: str) -> dict: + """ + Verify a Sigstore DSSE bundle and return the parsed in-toto payload. + + Performs signature verification, certificate identity policy checks, + Rekor log consistency, payload type validation, and digest matching. + + Args: + bundle_json: Raw Sigstore bundle JSON + digest: Expected SHA256 hex digest of the DSSE payload subject + repo: GitHub repository (e.g. "tinfoilsh/confidential-router") + + Returns: + Parsed in-toto statement dict with predicateType, predicate, subject, etc. + + Raises: + ValueError: If any verification step fails + """ + verifier = Verifier.production() + bundle = Bundle.from_json(bundle_json) + + policy = AllOf([ + OIDCIssuer(OIDC_ISSUER), + GitHubWorkflowRepository(repo), + GitHubWorkflowRefPattern("refs/tags/.*") + ]) + + payload_type, payload_bytes = verifier.verify_dsse(bundle, policy) + + if payload_type != 'application/vnd.in-toto+json': + raise ValueError(f"Unsupported payload type: {payload_type}") + + statement = json.loads(payload_bytes) + + subjects = statement.get("subject", []) + if not subjects or "digest" not in subjects[0] or "sha256" not in subjects[0].get("digest", {}): + raise ValueError("Invalid in-toto statement: missing or empty subject") + + if digest != subjects[0]["digest"]["sha256"]: + raise ValueError( + f"Digest mismatch: expected {digest}, " + f"got {subjects[0]['digest']['sha256']}" + ) + + return statement + + def verify_attestation(bundle_json: bytes, digest: str, repo: str) -> Measurement: """ Verifies the attested measurements of an enclave image against a trusted root (Sigstore) @@ -48,45 +98,26 @@ def verify_attestation(bundle_json: bytes, digest: str, repo: str) -> Measuremen ValueError: If verification fails or digests don't match """ try: - # Create verifier with the trusted root - verifier = Verifier.production() - - # Parse the bundle - bundle = Bundle.from_json(bundle_json) - - # Create verification policy for GitHub Actions certificate identity - policy = AllOf([ - OIDCIssuer(OIDC_ISSUER), - GitHubWorkflowRepository(repo), - GitHubWorkflowRefPattern("refs/tags/.*") - ]) - - # --- Core DSSE Verification --- - # This verifies the signature on the DSSE envelope, applies the - # certificate identity policy, and checks Rekor log consistency. - # It returns the verified payload from within the envelope. - payload_type, payload_bytes = verifier.verify_dsse(bundle, policy) - - # --- Process the Verified Payload --- - if payload_type != 'application/vnd.in-toto+json': - raise ValueError(f"Unsupported payload type: {payload_type}. Only supports In-toto.") - - result_json = json.loads(payload_bytes) - predicate_type = PredicateType(result_json["predicateType"]) - predicate_fields = result_json["predicate"] - - # --- Manual Payload Digest Verification --- - # Now, verify that the provided external digest matches the - # actual digest in the payload returned from the verified envelope. - if digest != result_json["subject"][0]["digest"]["sha256"]: - raise ValueError( - f"Provided digest does not match verified DSSE payload digest. " - f"Expected: {digest}, Got: {result_json['subject'][0]['digest']['sha256']}" - ) - - # Convert predicate type to measurement type + statement = _verify_dsse_bundle(bundle_json, digest, repo) + + predicate_type = PredicateType(statement["predicateType"]) + predicate_fields = statement["predicate"] + if predicate_type == PredicateType.SNP_TDX_MULTIPLATFORM_v1: - registers = [predicate_fields["snp_measurement"]] + snp_measurement = predicate_fields.get("snp_measurement") + if not snp_measurement: + raise ValueError("Invalid multiplatform measurement: no snp_measurement") + + tdx_measurement = predicate_fields.get("tdx_measurement") + if not tdx_measurement: + raise ValueError("Invalid multiplatform measurement: no tdx_measurement") + + rtmr1 = tdx_measurement.get("rtmr1") + rtmr2 = tdx_measurement.get("rtmr2") + if not rtmr1 or not rtmr2: + raise ValueError("Invalid multiplatform measurement: missing rtmr1 or rtmr2") + + registers = [snp_measurement, rtmr1, rtmr2] else: raise ValueError(f"Unsupported predicate type: {predicate_type}") @@ -94,6 +125,74 @@ def verify_attestation(bundle_json: bytes, digest: str, repo: str) -> Measuremen type=predicate_type, registers=registers ) - + except Exception as e: raise ValueError(f"Attestation processing failed: {e}") from e + + +HARDWARE_MEASUREMENTS_REPO = "tinfoilsh/hardware-measurements" + + +def verify_hardware_measurements(bundle_json: bytes, digest: str, repo: str) -> List[HardwareMeasurement]: + """ + Verifies hardware measurements from a Sigstore bundle. + + Args: + bundle_json: The bundle JSON data (bytes) + digest: The expected hex-encoded SHA256 digest + repo: The repository name + + Returns: + List of HardwareMeasurement objects + + Raises: + ValueError: If verification fails or predicate type is unexpected + """ + try: + statement = _verify_dsse_bundle(bundle_json, digest, repo) + + predicate_type = statement["predicateType"] + if predicate_type != PredicateType.HARDWARE_MEASUREMENTS_V1.value: + raise ValueError(f"Unexpected predicate type: {predicate_type}") + + predicate_fields = statement["predicate"] + measurements = [] + + for platform_id, platform_data in predicate_fields.items(): + if not isinstance(platform_data, dict): + raise ValueError(f"Invalid hardware measurement for {platform_id}") + + mrtd = platform_data.get("mrtd") + rtmr0 = platform_data.get("rtmr0") + + if not mrtd or not rtmr0: + raise ValueError(f"Invalid hardware measurement for {platform_id}: missing mrtd or rtmr0") + + measurements.append(HardwareMeasurement( + id=f"{platform_id}@{digest}", + mrtd=mrtd, + rtmr0=rtmr0, + )) + + return measurements + + except Exception as e: + raise ValueError(f"Hardware measurements processing failed: {e}") from e + + +def fetch_latest_hardware_measurements() -> List[HardwareMeasurement]: + """ + Fetches the latest hardware measurements from GitHub + Sigstore. + + Returns: + List of HardwareMeasurement objects + + Raises: + ValueError: If fetching or verification fails + """ + try: + digest = fetch_latest_digest(HARDWARE_MEASUREMENTS_REPO) + bundle_json = fetch_attestation_bundle(HARDWARE_MEASUREMENTS_REPO, digest) + return verify_hardware_measurements(bundle_json, digest, HARDWARE_MEASUREMENTS_REPO) + except Exception as e: + raise ValueError(f"Hardware measurements fetching failed: {e}") from e diff --git a/tests/test_attestation_all_enclaves.py b/tests/test_attestation_all_enclaves.py new file mode 100644 index 0000000..1ee8700 --- /dev/null +++ b/tests/test_attestation_all_enclaves.py @@ -0,0 +1,189 @@ +""" +Integration test for all live enclaves from the config file. + +Fetches config from GitHub, extracts all model enclaves, and verifies +attestation for each one. + +Configure via environment variables: + TINFOIL_CONFIG_URL - URL to config.yml (default: main branch) + TINFOIL_API_KEY - API key for model tests (optional) + +Example: + python -m pytest tests/test_all_enclaves.py -v -s + python -m pytest tests/test_all_enclaves.py -v -s -k "llama" # filter by model name +""" + +import os +import pytest +import requests + +try: + import yaml + HAS_YAML = True +except ImportError: + HAS_YAML = False + +from tinfoil.client import SecureClient +from tinfoil.attestation import PredicateType + +pytestmark = pytest.mark.integration + +# Config URL +CONFIG_URL = os.environ.get( + "TINFOIL_CONFIG_URL", + "https://raw.githubusercontent.com/tinfoilsh/confidential-model-router/refs/heads/main/config.yml" +) + +# Models to skip (not testable) +SKIP_MODELS = ["websearch"] + + +def fetch_config() -> dict: + """Fetch and parse the config file.""" + if not HAS_YAML: + raise ImportError("PyYAML not installed. Install with: pip install pyyaml") + response = requests.get(CONFIG_URL, timeout=30) + response.raise_for_status() + return yaml.safe_load(response.text) + + +def get_all_enclaves() -> list[tuple[str, str, str]]: + """ + Get all enclaves from config. + + Returns: + List of (model_name, repo, hostname) tuples + + Raises: + ImportError: If PyYAML is not installed + Exception: If the config cannot be fetched or parsed + """ + if not HAS_YAML: + raise ImportError("PyYAML not installed. Install with: pip install pyyaml") + + config = fetch_config() + + enclaves = [] + models = config.get("models", {}) + + for model_name, model_config in models.items(): + if any(skip.lower() in model_name.lower() for skip in SKIP_MODELS): + continue + + repo = model_config.get("repo", "") + hostnames = model_config.get("enclaves", []) or model_config.get("hostnames", []) + + if not repo or not hostnames: + continue + + for hostname in hostnames: + enclaves.append((model_name, repo, hostname)) + + return enclaves + + +def pytest_generate_tests(metafunc): + """Defer network call to test generation time instead of module import.""" + if "enclave_config" in metafunc.fixturenames: + try: + enclaves = get_all_enclaves() + except ImportError: + enclaves = [] + except Exception: + enclaves = [] + + if not enclaves: + enclaves = [("__skip__", "__skip__", "__skip__")] + ids = ["no_enclaves"] + else: + ids = [f"{n}@{h}" for n, _, h in enclaves] + + metafunc.parametrize("enclave_config", enclaves, ids=ids) + + +def test_enclave_attestation(enclave_config): + """ + Test attestation for a single enclave. + + Verifies: + 1. Can connect to enclave + 2. Attestation verification passes (crypto + policy) + 3. Sigstore verification passes + 4. Measurements match + 5. For TDX: hardware measurements verified + """ + model_name, repo, hostname = enclave_config + + if model_name == "__skip__": + if not HAS_YAML: + pytest.skip("PyYAML not installed. Install with: pip install pyyaml") + else: + pytest.skip("No enclaves found in config") + + print(f"\n{'='*60}") + print(f"Testing: {model_name}") + print(f" Enclave: {hostname}") + print(f" Repo: {repo}") + print(f"{'='*60}") + + try: + client = SecureClient(enclave=hostname, repo=repo) + ground_truth = client.verify() + + # Print results + measurement_type = ground_truth.measurement.type + print(f"\n✓ Attestation verified!") + print(f" Architecture: {measurement_type.value}") + print(f" Fingerprint: {ground_truth.measurement.fingerprint()[:32]}...") + print(f" Public key: {ground_truth.public_key[:32]}...") + print(f" Digest: {ground_truth.digest[:32]}...") + + # Print architecture-specific info + regs = ground_truth.measurement.registers + if measurement_type == PredicateType.SEV_GUEST_V2: + print(f" SNP measurement: {regs[0][:32]}...") + elif measurement_type == PredicateType.TDX_GUEST_V2: + print(f" MRTD: {regs[0][:32]}...") + print(f" RTMR0: {regs[1][:32]}...") + + except Exception as e: + pytest.fail(f"Attestation failed for {model_name}@{hostname}: {e}") + + +def test_summary(): + """Print summary of all enclaves that will be tested.""" + if not HAS_YAML: + pytest.skip("PyYAML not installed. Install with: pip install pyyaml") + + try: + enclaves = get_all_enclaves() + except Exception as e: + pytest.fail(f"Failed to fetch enclave config: {e}") + + if not enclaves: + pytest.skip("No enclaves found in config") + + print(f"\n{'='*60}") + print(f"ENCLAVE SUMMARY") + print(f"{'='*60}") + print(f"Config: {CONFIG_URL}") + print(f"Total enclaves: {len(enclaves)}") + print(f"\nModels:") + + # Group by model + models = {} + for name, repo, host in enclaves: + if name not in models: + models[name] = {"repo": repo, "hosts": []} + models[name]["hosts"].append(host) + + for name, info in sorted(models.items()): + print(f"\n {name}:") + print(f" Repo: {info['repo']}") + print(f" Enclaves: {len(info['hosts'])}") + for host in info['hosts']: + print(f" - {host}") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_attestation_flow.py b/tests/test_attestation_flow.py index 2f8cf28..750c911 100644 --- a/tests/test_attestation_flow.py +++ b/tests/test_attestation_flow.py @@ -1,89 +1,97 @@ +""" +Integration test for attestation verification flow. + +Tests the complete verification using the SecureClient API with a live router. +Works with any architecture (SNP or TDX) returned by the router service. +""" + import pytest -# Adjust these imports based on your project structure -from tinfoil.github import fetch_latest_digest, fetch_attestation_bundle -from tinfoil.sigstore import verify_attestation -from tinfoil.attestation import fetch_attestation -from tinfoil.client import get_router_address +from tinfoil.client import SecureClient, get_router_address +from tinfoil.attestation import PredicateType pytestmark = pytest.mark.integration # allows pytest -m integration filtering -# Fetch config from environment variables, falling back to defaults -# Use the same env vars as the other integration test for consistency +# Router always runs confidential-model-router REPO = "tinfoilsh/confidential-model-router" + def test_full_verification_flow(): """ - Tests the complete attestation verification flow: - 1. Fetch latest digest for the repository. - 2. Fetch the sigstore attestation bundle for that digest. - 3. Verify the sigstore bundle to get code measurements. - 4. Fetch the runtime attestation from the enclave. - 5. Verify the runtime attestation. - 6. Compare code measurements with runtime measurements. + Tests the complete attestation verification flow using SecureClient. + + Gets a router from the ATC service and verifies it against the + confidential-model-router repo. Works with any TEE type (SNP or TDX). + + SecureClient.verify() performs: + 1. Fetch runtime attestation from enclave + 2. Verify attestation (cryptographic + policy validation) + 3. Fetch latest digest from GitHub + 4. Fetch and verify sigstore attestation bundle + 5. For TDX: verify hardware measurements (MRTD, RTMR0) + 6. Compare code measurements with runtime measurements """ try: - # Fetch enclave address lazily inside the test to avoid import-time network calls - try: - enclave = get_router_address() - except Exception as e: - pytest.skip(f"Could not fetch router address from ATC service: {e}") - return - - # Fetch latest release digest - print(f"Fetching latest release for {REPO}") - digest = fetch_latest_digest(REPO) - print(f"Found digest: {digest}") - - # Fetch attestation bundle - print(f"Fetching attestation bundle for {REPO}@{digest}") - sigstore_bundle = fetch_attestation_bundle(REPO, digest) - assert sigstore_bundle is not None # Basic check - - # Verify attested measurements from sigstore bundle - print(f"Verifying attested measurements for {REPO}@{digest}") - code_measurements = verify_attestation( - sigstore_bundle, - digest, - REPO - ) - assert code_measurements is not None # Basic check - print(f"Code measurements fingerprint: {code_measurements.fingerprint()}") - - - # Fetch runtime attestation from the enclave - print(f"Fetching runtime attestation from {enclave}") - enclave_attestation = fetch_attestation(enclave) - assert enclave_attestation is not None # Basic check - - # Verify enclave measurements from runtime attestation - print("Verifying enclave measurements") - runtime_verification = enclave_attestation.verify() - assert runtime_verification is not None # Basic check - print(f"Runtime measurement fingerprint: {runtime_verification.measurement.fingerprint()}") - print(f"Public key fingerprint: {runtime_verification.public_key_fp}") - - - # Compare measurements - print("Comparing measurements") - assert len(code_measurements.registers) == len(runtime_verification.measurement.registers), \ - "Number of measurement registers differ" - - for i, code_reg in enumerate(code_measurements.registers): - runtime_reg = runtime_verification.measurement.registers[i] - assert code_reg == runtime_reg, \ - f"Measurement register {i} mismatch: Code='{code_reg}' vs Runtime='{runtime_reg}'" - - print("Verification successful!") - print(f"Public key fingerprint: {runtime_verification.public_key_fp}") - print(f"Measurement: {code_measurements.fingerprint()}") - + enclave = get_router_address() except Exception as e: - import traceback - traceback.print_exc() - pytest.fail(f"Verification flow failed with exception: {e}") + pytest.skip(f"Could not fetch router address from ATC service: {e}") + + print(f"\nVerifying enclave: {enclave}") + print(f"Against repo: {REPO}") + + client = SecureClient(enclave=enclave, repo=REPO) + ground_truth = client.verify() + + # Print architecture-specific info + measurement_type = ground_truth.measurement.type + print(f"\n✓ Verification successful!") + print(f" Architecture: {measurement_type.value}") + print(f" Measurement fingerprint: {ground_truth.measurement.fingerprint()}") + print(f" Public key fingerprint: {ground_truth.public_key}") + print(f" Digest: {ground_truth.digest}") + + # Print registers based on type + regs = ground_truth.measurement.registers + if measurement_type == PredicateType.SEV_GUEST_V2: + print(f"\n SNP Measurement: {regs[0][:32]}...") + elif measurement_type == PredicateType.TDX_GUEST_V2: + print(f"\n TDX Measurements:") + print(f" MRTD: {regs[0][:32]}...") + print(f" RTMR0: {regs[1][:32]}...") + print(f" RTMR1: {regs[2][:32]}...") + print(f" RTMR2: {regs[3][:32]}...") + print(f" RTMR3: {regs[4][:32]}...") + + +def test_secure_http_client(): + """ + Tests that SecureClient creates a working pinned HTTP client + and that TLS pinning is exercised by issuing an actual request. + Works with any TEE type (SNP or TDX). + """ + try: + enclave = get_router_address() + except Exception as e: + pytest.skip(f"Could not fetch router address from ATC service: {e}") + + print(f"\nCreating secure HTTP client for: {enclave}") + + client = SecureClient(enclave=enclave, repo=REPO) + http_client = client.make_secure_http_client() + + ground_truth = client.ground_truth + assert ground_truth is not None + + try: + response = http_client.get(f"https://{enclave}/.well-known/tinfoil-attestation") + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + print(f"\n✓ TLS-pinned request succeeded (status {response.status_code})") + finally: + http_client.close() + + print(f" Architecture: {ground_truth.measurement.type.value}") + print(f" TLS pinned to: {ground_truth.public_key}") if __name__ == "__main__": - # Allow running the test directly using `python tests/test_verification_flow.py` - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_collateral_tdx.py b/tests/test_collateral_tdx.py new file mode 100644 index 0000000..0735e70 --- /dev/null +++ b/tests/test_collateral_tdx.py @@ -0,0 +1,1892 @@ +""" +Unit tests for TDX collateral fetching and TCB validation. +""" + +import json +import pytest +from datetime import datetime, timezone, timedelta +from unittest.mock import patch, MagicMock + +from tinfoil.attestation.collateral_tdx import ( + TcbStatus, + TcbComponent, + Tcb, + TcbLevel, + TcbInfo, + TdxTcbInfo, + TdxModuleIdentity, + EnclaveIdentity, + QeIdentity, + TdxCollateral, + CollateralError, + parse_tcb_info_response, + parse_qe_identity_response, + is_cpu_svn_higher_or_equal, + is_tdx_tcb_svn_higher_or_equal, + get_matching_tcb_level, + get_matching_qe_tcb_level, + get_tdx_module_identity, + validate_tcb_status, + validate_tdx_module_identity, + validate_qe_identity, + check_collateral_freshness, + fetch_tcb_info, + fetch_qe_identity, + _parse_datetime, + _parse_issuer_chain_header, + _verify_collateral_signature, + _parse_hex_bytes, + _get_tcb_info_cache_path, + _get_qe_identity_cache_path, + _is_tcb_info_fresh, + _is_qe_identity_fresh, + _read_cache, + _write_cache, +) +from tinfoil.attestation.pck_extensions import PckExtensions, PckCertTCB + + +# ============================================================================= +# Sample Data - Based on Real Intel PCS Responses +# ============================================================================= + +SAMPLE_TCB_INFO_JSON = """ +{ + "tcbInfo": { + "id": "TDX", + "version": 3, + "issueDate": "2025-12-17T06:24:56Z", + "nextUpdate": "2026-01-16T06:24:56Z", + "fmspc": "90c06f000000", + "pceId": "0000", + "tcbType": 0, + "tcbEvaluationDataNumber": 18, + "tdxModule": { + "mrsigner": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "attributes": "0000000000000000", + "attributesMask": "FFFFFFFFFFFFFFFF" + }, + "tdxModuleIdentities": [ + { + "id": "TDX_03", + "mrsigner": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "attributes": "0000000000000000", + "attributesMask": "FFFFFFFFFFFFFFFF", + "tcbLevels": [ + { + "tcb": { "isvsvn": 3 }, + "tcbDate": "2024-11-13T00:00:00Z", + "tcbStatus": "UpToDate" + } + ] + } + ], + "tcbLevels": [ + { + "tcb": { + "sgxtcbcomponents": [ + {"svn": 3}, {"svn": 3}, {"svn": 2}, {"svn": 2}, + {"svn": 2}, {"svn": 1}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0} + ], + "pcesvn": 13, + "tdxtcbcomponents": [ + {"svn": 5, "category": "TDX Module", "type": "TDX Module"}, + {"svn": 0}, {"svn": 3}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0} + ] + }, + "tcbDate": "2024-11-13T00:00:00Z", + "tcbStatus": "UpToDate" + }, + { + "tcb": { + "sgxtcbcomponents": [ + {"svn": 2}, {"svn": 2}, {"svn": 2}, {"svn": 2}, + {"svn": 2}, {"svn": 1}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0} + ], + "pcesvn": 11, + "tdxtcbcomponents": [ + {"svn": 4}, {"svn": 0}, {"svn": 2}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0}, + {"svn": 0}, {"svn": 0}, {"svn": 0}, {"svn": 0} + ] + }, + "tcbDate": "2024-03-13T00:00:00Z", + "tcbStatus": "OutOfDate" + } + ] + }, + "signature": "abcd1234" +} +""" + +SAMPLE_QE_IDENTITY_JSON = """ +{ + "enclaveIdentity": { + "id": "TD_QE", + "version": 2, + "issueDate": "2025-12-17T18:48:11Z", + "nextUpdate": "2026-01-16T18:48:11Z", + "tcbEvaluationDataNumber": 18, + "miscselect": "00000000", + "miscselectMask": "FFFFFFFF", + "attributes": "11000000000000000000000000000000", + "attributesMask": "FBFFFFFFFFFFFFFF0000000000000000", + "mrsigner": "DC9E2A7C6F948F17474E34A7FC43ED030F7C1563F1BABDDF6340C82E0E54A8C5", + "isvprodid": 2, + "tcbLevels": [ + { + "tcb": { "isvsvn": 4 }, + "tcbDate": "2024-11-13T00:00:00Z", + "tcbStatus": "UpToDate" + }, + { + "tcb": { "isvsvn": 3 }, + "tcbDate": "2024-03-13T00:00:00Z", + "tcbStatus": "OutOfDate" + } + ] + }, + "signature": "0665a932" +} +""" + +# Byte versions for cache testing +SAMPLE_TCB_INFO_RESPONSE = SAMPLE_TCB_INFO_JSON.encode() +SAMPLE_QE_IDENTITY_RESPONSE = SAMPLE_QE_IDENTITY_JSON.encode() + +# Stale TCB Info (next_update in the past) +SAMPLE_STALE_TCB_INFO_JSON = """ +{ + "tcbInfo": { + "id": "TDX", + "version": 3, + "issueDate": "2024-01-17T06:24:56Z", + "nextUpdate": "2024-02-16T06:24:56Z", + "fmspc": "90c06f000000", + "pceId": "0000", + "tcbType": 0, + "tcbEvaluationDataNumber": 18, + "tdxModuleIdentities": [], + "tcbLevels": [] + }, + "signature": "stale1234" +} +""" +SAMPLE_STALE_TCB_INFO_RESPONSE = SAMPLE_STALE_TCB_INFO_JSON.encode() + + +# ============================================================================= +# Helper Functions +# ============================================================================= + +def create_sample_pck_extensions() -> PckExtensions: + """Create sample PCK extensions for testing.""" + return PckExtensions( + ppid="00" * 16, + tcb=PckCertTCB( + pce_svn=13, + cpu_svn=bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + tcb_components=[3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ), + pceid="0000", + fmspc="90c06f000000", + ) + + +# ============================================================================= +# Parsing Tests +# ============================================================================= + +class TestParseTcbInfoResponse: + """Test TCB Info parsing.""" + + def test_parse_valid_response(self): + """Test parsing a valid TCB Info response.""" + result = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + assert result.tcb_info.id == "TDX" + assert result.tcb_info.version == 3 + assert result.tcb_info.fmspc == "90c06f000000" + assert len(result.tcb_info.tcb_levels) == 2 + assert result.signature == "abcd1234" + + def test_parse_tcb_levels(self): + """Test that TCB levels are parsed correctly.""" + result = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + first_level = result.tcb_info.tcb_levels[0] + assert first_level.tcb_status == TcbStatus.UP_TO_DATE + assert first_level.tcb.pce_svn == 13 + assert len(first_level.tcb.sgx_tcb_components) == 16 + assert first_level.tcb.sgx_tcb_components[0].svn == 3 + + def test_parse_tdx_module_identities(self): + """Test that TDX module identities are parsed.""" + result = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + assert len(result.tcb_info.tdx_module_identities) == 1 + identity = result.tcb_info.tdx_module_identities[0] + assert identity.id == "TDX_03" + + def test_reject_wrong_id(self): + """Test that non-TDX ID is rejected.""" + bad_json = SAMPLE_TCB_INFO_JSON.replace('"id": "TDX"', '"id": "SGX"') + with pytest.raises(CollateralError, match="must be 'TDX'"): + parse_tcb_info_response(bad_json.encode()) + + def test_reject_wrong_version(self): + """Test that wrong version is rejected.""" + bad_json = SAMPLE_TCB_INFO_JSON.replace('"version": 3', '"version": 2') + with pytest.raises(CollateralError, match="must be 3"): + parse_tcb_info_response(bad_json.encode()) + + +class TestParseQeIdentityResponse: + """Test QE Identity parsing.""" + + def test_parse_valid_response(self): + """Test parsing a valid QE Identity response.""" + result = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + assert result.enclave_identity.id == "TD_QE" + assert result.enclave_identity.version == 2 + assert result.enclave_identity.isv_prod_id == 2 + assert len(result.enclave_identity.tcb_levels) == 2 + + def test_parse_mrsigner(self): + """Test MRSIGNER is parsed correctly.""" + result = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + expected = bytes.fromhex( + "DC9E2A7C6F948F17474E34A7FC43ED030F7C1563F1BABDDF6340C82E0E54A8C5" + ) + assert result.enclave_identity.mrsigner == expected + + def test_reject_wrong_id(self): + """Test that non-TD_QE ID is rejected.""" + bad_json = SAMPLE_QE_IDENTITY_JSON.replace('"id": "TD_QE"', '"id": "QE"') + with pytest.raises(CollateralError, match="must be 'TD_QE'"): + parse_qe_identity_response(bad_json.encode()) + + +# ============================================================================= +# TCB Comparison Tests +# ============================================================================= + +class TestIsCpuSvnHigherOrEqual: + """Test CPU SVN comparison.""" + + def test_equal_values(self): + """Test equal SVN values pass.""" + cpu_svn = bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + components = [TcbComponent(svn=3), TcbComponent(svn=3), TcbComponent(svn=2), + TcbComponent(svn=2), TcbComponent(svn=2), TcbComponent(svn=1), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_cpu_svn_higher_or_equal(cpu_svn, components) is True + + def test_higher_values(self): + """Test higher SVN values pass.""" + cpu_svn = bytes([4, 4, 3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + components = [TcbComponent(svn=3), TcbComponent(svn=3), TcbComponent(svn=2), + TcbComponent(svn=2), TcbComponent(svn=2), TcbComponent(svn=1), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_cpu_svn_higher_or_equal(cpu_svn, components) is True + + def test_lower_first_component(self): + """Test lower first component fails.""" + cpu_svn = bytes([2, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + components = [TcbComponent(svn=3), TcbComponent(svn=3), TcbComponent(svn=2), + TcbComponent(svn=2), TcbComponent(svn=2), TcbComponent(svn=1), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_cpu_svn_higher_or_equal(cpu_svn, components) is False + + def test_wrong_length(self): + """Test wrong length fails.""" + cpu_svn = bytes([3, 3, 2]) # Only 3 bytes + components = [TcbComponent(svn=3)] * 16 + + assert is_cpu_svn_higher_or_equal(cpu_svn, components) is False + + +class TestIsTdxTcbSvnHigherOrEqual: + """Test TDX TCB SVN comparison.""" + + def test_equal_values(self): + """Test equal SVN values pass.""" + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + components = [TcbComponent(svn=5), TcbComponent(svn=0), TcbComponent(svn=3), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_tdx_tcb_svn_higher_or_equal(tee_tcb_svn, components) is True + + def test_skip_first_two_when_module_version_set(self): + """Test that first 2 bytes are skipped when tee_tcb_svn[1] > 0.""" + # tee_tcb_svn[1] = 1, so first 2 bytes should be skipped + tee_tcb_svn = bytes([0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + # Components have higher values in first 2 positions, but should be skipped + components = [TcbComponent(svn=99), TcbComponent(svn=99), TcbComponent(svn=3), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_tdx_tcb_svn_higher_or_equal(tee_tcb_svn, components) is True + + def test_lower_value_fails(self): + """Test lower SVN value fails.""" + tee_tcb_svn = bytes([5, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + components = [TcbComponent(svn=5), TcbComponent(svn=0), TcbComponent(svn=3), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0), TcbComponent(svn=0), TcbComponent(svn=0), + TcbComponent(svn=0)] + + assert is_tdx_tcb_svn_higher_or_equal(tee_tcb_svn, components) is False + + +class TestGetMatchingTcbLevel: + """Test TCB level matching.""" + + def test_find_matching_level(self): + """Test finding a matching TCB level.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + pce_svn = 13 + cpu_svn = bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result = get_matching_tcb_level( + tcb_info.tcb_info.tcb_levels, + tee_tcb_svn, + pce_svn, + cpu_svn, + ) + + assert result is not None + assert result.tcb_status == TcbStatus.UP_TO_DATE + + def test_no_matching_level(self): + """Test no matching level found.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + # Very low SVN values - won't match any level + tee_tcb_svn = bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + pce_svn = 1 + cpu_svn = bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result = get_matching_tcb_level( + tcb_info.tcb_info.tcb_levels, + tee_tcb_svn, + pce_svn, + cpu_svn, + ) + + assert result is None + + +class TestGetMatchingQeTcbLevel: + """Test QE TCB level matching with isvsvn.""" + + def test_find_matching_level(self): + """Test finding a matching QE TCB level by isvsvn.""" + tcb_levels = [ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=8, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=6, + ), + tcb_date="2023-06-01T00:00:00Z", + tcb_status=TcbStatus.SW_HARDENING_NEEDED, + advisory_ids=["INTEL-SA-00001"], + ), + ] + + # ISV SVN 8 should match first level (equal) + result = get_matching_qe_tcb_level(tcb_levels, 8) + assert result is not None + assert result.tcb_status == TcbStatus.UP_TO_DATE + + # ISV SVN 10 should also match first level (>= 8) + result = get_matching_qe_tcb_level(tcb_levels, 10) + assert result is not None + assert result.tcb_status == TcbStatus.UP_TO_DATE + + # ISV SVN 7 should match second level (>= 6 but < 8) + result = get_matching_qe_tcb_level(tcb_levels, 7) + assert result is not None + assert result.tcb_status == TcbStatus.SW_HARDENING_NEEDED + + def test_no_matching_level(self): + """Test no matching QE TCB level when isvsvn too low.""" + tcb_levels = [ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=8, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ] + + # ISV SVN 5 is less than 8, so no match + result = get_matching_qe_tcb_level(tcb_levels, 5) + assert result is None + + def test_empty_levels(self): + """Test empty TCB levels list.""" + result = get_matching_qe_tcb_level([], 8) + assert result is None + + def test_level_without_isv_svn(self): + """Test TCB level without isv_svn field is skipped.""" + tcb_levels = [ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=None, # No isv_svn set + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ] + + result = get_matching_qe_tcb_level(tcb_levels, 8) + assert result is None + + +# ============================================================================= +# Validation Tests +# ============================================================================= + +class TestValidateTcbStatus: + """Test TCB status validation.""" + + def test_up_to_date_passes(self): + """Test UpToDate status passes.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + pck_ext = create_sample_pck_extensions() + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result = validate_tcb_status( + tcb_info.tcb_info, + tee_tcb_svn, + pck_ext, + ) + + assert result.tcb_status == TcbStatus.UP_TO_DATE + + def test_no_matching_level_fails(self): + """Test that no matching level raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + # Create extensions with very low SVN values + pck_ext = PckExtensions( + ppid="00" * 16, + tcb=PckCertTCB( + pce_svn=1, + cpu_svn=bytes(16), + tcb_components=[0] * 16, + ), + pceid="0000", + fmspc="90c06f000000", + ) + tee_tcb_svn = bytes(16) + + with pytest.raises(CollateralError, match="No matching TCB level"): + validate_tcb_status(tcb_info.tcb_info, tee_tcb_svn, pck_ext) + + def test_fmspc_mismatch_fails(self): + """Test that FMSPC mismatch raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + # Create extensions with different FMSPC + pck_ext = PckExtensions( + ppid="00" * 16, + tcb=PckCertTCB( + pce_svn=13, + cpu_svn=bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + tcb_components=[3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ), + pceid="0000", + fmspc="aabbcc000000", # Different FMSPC + ) + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + with pytest.raises(CollateralError, match="FMSPC mismatch"): + validate_tcb_status(tcb_info.tcb_info, tee_tcb_svn, pck_ext) + + def test_pceid_mismatch_fails(self): + """Test that PCE_ID mismatch raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + + # Create extensions with different PCE_ID + pck_ext = PckExtensions( + ppid="00" * 16, + tcb=PckCertTCB( + pce_svn=13, + cpu_svn=bytes([3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + tcb_components=[3, 3, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ), + pceid="1234", # Different PCE_ID + fmspc="90c06f000000", + ) + tee_tcb_svn = bytes([5, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + with pytest.raises(CollateralError, match="PCE_ID mismatch"): + validate_tcb_status(tcb_info.tcb_info, tee_tcb_svn, pck_ext) + + +class TestValidateQeIdentity: + """Test QE identity validation.""" + + def _create_qe_identity(self) -> EnclaveIdentity: + """Create a sample QE Identity for testing.""" + return EnclaveIdentity( + id="TD_QE", + version=2, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + tcb_evaluation_data_number=17, + miscselect=b"\x00\x00\x00\x00", + miscselect_mask=b"\xff\xff\xff\xff", + attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + attributes_mask=b"\xfb\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00", + mrsigner=b"\xdc" * 32, + isv_prod_id=1, + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=8, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ) + + def test_valid_qe_identity(self): + """Test validation passes with matching QE identity.""" + qe_identity = self._create_qe_identity() + result = validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=1, + ) + assert result.tcb_status == TcbStatus.UP_TO_DATE + + def test_mrsigner_mismatch(self): + """Test validation fails with MRSIGNER mismatch.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="MRSIGNER does not match"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xaa" * 32, # Wrong MRSIGNER + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=1, + ) + + def test_miscselect_mismatch(self): + """Test validation fails with MISCSELECT mismatch under mask.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="MISCSELECT does not match"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x01\x00\x00\x00", # Wrong MISCSELECT + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=1, + ) + + def test_attributes_mismatch(self): + """Test validation fails with Attributes mismatch under mask.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="Attributes do not match"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", # Wrong + qe_report_isvprodid=1, + ) + + def test_isvprodid_mismatch(self): + """Test validation fails with ISV ProdID mismatch.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="ISV ProdID does not match"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=8, + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=2, # Wrong ISV ProdID + ) + + def test_isv_svn_too_low(self): + """Test validation fails when ISV SVN is too low.""" + qe_identity = self._create_qe_identity() + with pytest.raises(CollateralError, match="No matching QE TCB level"): + validate_qe_identity( + qe_identity=qe_identity, + qe_report_isv_svn=5, # Below required 8 + qe_report_mrsigner=b"\xdc" * 32, + qe_report_miscselect=b"\x00\x00\x00\x00", + qe_report_attributes=b"\x11\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + qe_report_isvprodid=1, + ) + + +class TestTdxModuleIdentity: + """Test TDX module identity validation.""" + + def _create_tcb_info_with_module_identities(self) -> TcbInfo: + """Create TCB Info with module identities for testing.""" + return TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=[ + TdxModuleIdentity( + id="TDX_03", + mrsigner=b"\xaa" * 48, + attributes=b"\x00\x00\x00\x00\x00\x00\x00\x00", + attributes_mask=b"\xff\xff\xff\xff\xff\xff\xff\xff", + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=3, # Minor version 3 + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ), + ], + tcb_levels=[], + ) + + def test_get_module_identity(self): + """Test finding module identity by TEE_TCB_SVN.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # Major version 3 should match TDX_03 + tee_tcb_svn = bytes([5, 3] + [0] * 14) # minor=5, major=3 + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_03" + + def test_get_module_identity_not_found(self): + """Test unknown module version raises error.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # Major version 5 should not match any (only TDX_03 exists) + tee_tcb_svn = bytes([0, 5] + [0] * 14) # minor=0, major=5 + with pytest.raises(CollateralError, match="Unknown TDX module version TDX_05"): + get_tdx_module_identity(tcb_info, tee_tcb_svn) + + def test_validate_module_identity_success(self): + """Test successful module identity validation.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # minor=5 (used for TCB level matching), major=3 (used to find TDX_03) + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = b"\xaa" * 48 + seam_attributes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + # Should not raise - validation succeeds + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_mrsigner_mismatch(self): + """Test module identity validation fails with MR_SIGNER_SEAM mismatch.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + tee_tcb_svn = bytes([5, 3] + [0] * 14) # minor=5, major=3 + mr_signer_seam = b"\xbb" * 48 # Wrong MR_SIGNER_SEAM + seam_attributes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + with pytest.raises(CollateralError, match="MR_SIGNER_SEAM does not match"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_unknown_version_raises(self): + """Test validation raises when module version is unknown.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # Major version 5 doesn't exist in module identities (only TDX_03) + tee_tcb_svn = bytes([0, 5] + [0] * 14) # minor=0, major=5 + mr_signer_seam = b"\xaa" * 48 + seam_attributes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + with pytest.raises(CollateralError, match="Unknown TDX module version TDX_05"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_no_matching_tcb_level(self): + """Test validation raises when minor version is too low for any TCB level.""" + tcb_info = self._create_tcb_info_with_module_identities() + + # TEE_TCB_SVN[0]=minor, TEE_TCB_SVN[1]=major + # Module identity TDX_03 has TCB level with isv_svn=3 + # minor=2 < 3, so no TCB level should match + tee_tcb_svn = bytes([2, 3] + [0] * 14) # minor=2, major=3 + mr_signer_seam = b"\xaa" * 48 + seam_attributes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + + with pytest.raises(CollateralError, match="Could not find a TDX Module Identity TCB Level"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + +class TestCheckCollateralFreshness: + """Test collateral freshness checking.""" + + def test_fresh_collateral_passes(self): + """Test fresh collateral passes.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + # Mock datetime to a time before sample data's nextUpdate (2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + # Should not raise (sample has tcbEvaluationDataNumber=18) + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=18) + + def test_expired_tcb_info_fails(self): + """Test expired TCB Info fails.""" + # Modify the sample to have an expired date + expired_json = SAMPLE_TCB_INFO_JSON.replace( + '"nextUpdate": "2026-01-16T06:24:56Z"', + '"nextUpdate": "2020-01-16T06:24:56Z"' + ) + tcb_info = parse_tcb_info_response(expired_json.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=expired_json.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + with pytest.raises(CollateralError, match="TCB Info has expired"): + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=18) + + def test_tcb_evaluation_data_number_threshold_passes(self): + """Test collateral passes when tcbEvaluationDataNumber meets threshold.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + # Mock datetime to a time before sample data's nextUpdate (2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + # Sample data has tcbEvaluationDataNumber=18, threshold of 18 should pass + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=18) + + # Lower threshold should also pass + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=10) + + def test_tcb_info_evaluation_data_number_below_threshold(self): + """Test failure when TCB Info tcbEvaluationDataNumber is below threshold.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + # Mock datetime to a time before sample data's nextUpdate (2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + # Sample data has tcbEvaluationDataNumber=18, threshold of 19 should fail + with pytest.raises(CollateralError, match="TCB Info tcbEvaluationDataNumber .* is below"): + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=19) + + def test_qe_identity_evaluation_data_number_below_threshold(self): + """Test failure when QE Identity tcbEvaluationDataNumber is below threshold.""" + # Create TCB Info with high tcbEvaluationDataNumber + high_eval_tcb_json = SAMPLE_TCB_INFO_JSON.replace( + '"tcbEvaluationDataNumber": 18', + '"tcbEvaluationDataNumber": 25' + ) + tcb_info = parse_tcb_info_response(high_eval_tcb_json.encode()) + # QE Identity still has tcbEvaluationDataNumber=18 + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=high_eval_tcb_json.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + ) + + # Mock datetime to a time before sample data's nextUpdate (2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + # TCB Info has 25, QE Identity has 18, threshold of 20 should fail on QE Identity + with pytest.raises(CollateralError, match="QE Identity tcbEvaluationDataNumber .* is below"): + check_collateral_freshness(collateral, min_tcb_evaluation_data_number=20) + + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + +class TestHelperFunctions: + """Test helper functions.""" + + def test_parse_datetime(self): + """Test datetime parsing.""" + result = _parse_datetime("2025-12-17T06:24:56Z") + assert result.year == 2025 + assert result.month == 12 + assert result.day == 17 + assert result.tzinfo is not None + + def test_parse_hex_bytes(self): + """Test hex string parsing.""" + result = _parse_hex_bytes("aabbccdd") + assert result == bytes([0xaa, 0xbb, 0xcc, 0xdd]) + + +class TestTcbStatus: + """Test TcbStatus enum.""" + + def test_status_values(self): + """Test all status values can be created.""" + assert TcbStatus.UP_TO_DATE == "UpToDate" + assert TcbStatus.OUT_OF_DATE == "OutOfDate" + assert TcbStatus.REVOKED == "Revoked" + assert TcbStatus.SW_HARDENING_NEEDED == "SWHardeningNeeded" + + +# ============================================================================= +# Caching Tests +# ============================================================================= + +class TestCacheHelpers: + """Test cache helper functions.""" + + def test_tcb_info_cache_path(self): + """Test TCB Info cache path generation.""" + path = _get_tcb_info_cache_path("00A06D080000") + assert "tdx_tcb_info_00a06d080000.json" in path + # Should be lowercase + path2 = _get_tcb_info_cache_path("00a06d080000") + assert path == path2 + + def test_qe_identity_cache_path(self): + """Test QE Identity cache path generation.""" + path = _get_qe_identity_cache_path() + assert "tdx_qe_identity.json" in path + + def test_is_tcb_info_fresh(self): + """Test TCB Info freshness check.""" + # Fresh: next_update is in the future + fresh_tcb_info = TdxTcbInfo( + tcb_info=TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc) - timedelta(days=1), + next_update=datetime.now(timezone.utc) + timedelta(days=29), + fmspc="00A06D080000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=1, + tdx_module=None, + tdx_module_identities=[], + tcb_levels=[], + ), + signature="", + ) + assert _is_tcb_info_fresh(fresh_tcb_info) is True + + # Stale: next_update is in the past + stale_tcb_info = TdxTcbInfo( + tcb_info=TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc) - timedelta(days=31), + next_update=datetime.now(timezone.utc) - timedelta(days=1), + fmspc="00A06D080000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=1, + tdx_module=None, + tdx_module_identities=[], + tcb_levels=[], + ), + signature="", + ) + assert _is_tcb_info_fresh(stale_tcb_info) is False + + def test_is_qe_identity_fresh(self): + """Test QE Identity freshness check.""" + # Fresh + fresh_qe = QeIdentity( + enclave_identity=EnclaveIdentity( + id="TD_QE", + version=2, + issue_date=datetime.now(timezone.utc) - timedelta(days=1), + next_update=datetime.now(timezone.utc) + timedelta(days=29), + tcb_evaluation_data_number=1, + miscselect=b"\x00" * 4, + miscselect_mask=b"\xff" * 4, + attributes=b"\x00" * 16, + attributes_mask=b"\xff" * 16, + mrsigner=b"\xaa" * 32, + isv_prod_id=1, + tcb_levels=[], + ), + signature="", + ) + assert _is_qe_identity_fresh(fresh_qe) is True + + # Stale + stale_qe = QeIdentity( + enclave_identity=EnclaveIdentity( + id="TD_QE", + version=2, + issue_date=datetime.now(timezone.utc) - timedelta(days=31), + next_update=datetime.now(timezone.utc) - timedelta(days=1), + tcb_evaluation_data_number=1, + miscselect=b"\x00" * 4, + miscselect_mask=b"\xff" * 4, + attributes=b"\x00" * 16, + attributes_mask=b"\xff" * 16, + mrsigner=b"\xaa" * 32, + isv_prod_id=1, + tcb_levels=[], + ), + signature="", + ) + assert _is_qe_identity_fresh(stale_qe) is False + + +class TestIssuerChainParsing: + """Test issuer chain header parsing.""" + + def test_parse_issuer_chain_missing_certs(self): + """Test that parsing fails with too few certificates.""" + # Single certificate is not enough (need signing + root at minimum) + single_cert_pem = """-----BEGIN CERTIFICATE----- +MIICjzCCAjSgAwIBAgIUImUM1lqdNInzg7SVUr9QGzknBqwwCgYIKoZIzj0EAwIw +aDEaMBgGA1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENv +cnBvcmF0aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJ +BgNVBAYTAlVTMB4XDTE4MDUyMTEwNDUxMFoXDTQ5MTIzMTIzNTk1OVowaDEaMBgG +A1UEAwwRSW50ZWwgU0dYIFJvb3QgQ0ExGjAYBgNVBAoMEUludGVsIENvcnBvcmF0 +aW9uMRQwEgYDVQQHDAtTYW50YSBDbGFyYTELMAkGA1UECAwCQ0ExCzAJBgNVBAYT +AlVTMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEC6nEwMDIYZOj/iPWsCzaEKi7 +1OiOSLRFhWGjbnBVJfVnkY4u3IjkDYYL0MxO4mqsyYjlBalTVYxFP2sJBK5zlKOB +uzCBuDAfBgNVHSMEGDAWgBQiZQzWWp00ifODtJVSv1AbOScGrDBSBgNVHR8ESzBJ +MEegRaBDhkFodHRwczovL2NlcnRpZmljYXRlcy50cnVzdGVkc2VydmljZXMuaW50 +ZWwuY29tL0ludGVsU0dYUm9vdENBLmRlcjAdBgNVHQ4EFgQUImUM1lqdNInzg7SV +Ur9QGzknBqwwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQEwCgYI +KoZIzj0EAwIDSQAwRgIhAOW/5QkR+S9CiSDcNoowLuPRLsWGf/Yi7GSX94BgwTwg +AiEA4J0lrHoMs+Xo5o/sX6O9QWxHRAvZUGOdRQ7cvqRXaqI= +-----END CERTIFICATE-----""" + with pytest.raises(CollateralError, match="at least 2 certificates"): + _parse_issuer_chain_header(single_cert_pem) + + def test_parse_issuer_chain_invalid_pem(self): + """Test that parsing fails with invalid PEM data.""" + with pytest.raises(CollateralError, match="Failed to parse"): + _parse_issuer_chain_header("not-valid-pem-data") + + +class TestCollateralSignatureVerification: + """Test collateral signature verification.""" + + def test_verify_signature_missing_key(self): + """Test that verification fails when JSON key is missing.""" + json_data = b'{"other": "data"}' + with pytest.raises(CollateralError, match="does not contain"): + _verify_collateral_signature( + json_bytes=json_data, + json_key="tcbInfo", + signature_hex="00" * 64, + signing_cert=MagicMock(), + data_name="Test", + ) + + def test_verify_signature_invalid_hex(self): + """Test that verification fails with invalid signature hex.""" + json_data = b'{"tcbInfo": {}}' + with pytest.raises(CollateralError, match="not valid hex"): + _verify_collateral_signature( + json_bytes=json_data, + json_key="tcbInfo", + signature_hex="not-hex", + signing_cert=MagicMock(), + data_name="Test", + ) + + def test_verify_signature_wrong_length(self): + """Test that verification fails with wrong signature length.""" + json_data = b'{"tcbInfo": {}}' + with pytest.raises(CollateralError, match="expected 64"): + _verify_collateral_signature( + json_bytes=json_data, + json_key="tcbInfo", + signature_hex="00" * 32, # 32 bytes instead of 64 + signing_cert=MagicMock(), + data_name="Test", + ) + + +class TestFetchTcbInfoWithCache: + """Test fetch_tcb_info caching behavior.""" + + def test_cache_miss_fetches_from_network(self, tmp_path): + """Test that cache miss triggers network fetch.""" + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + with patch('tinfoil.attestation.collateral_tdx.verify_tcb_info_signature'): + mock_response = MagicMock() + mock_response.content = SAMPLE_TCB_INFO_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"TCB-Info-Issuer-Chain": "dummy"} + mock_get.return_value = mock_response + + # Also mock _parse_issuer_chain_header since we have dummy header + with patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse: + mock_parse.return_value = [] + + tcb_info, raw, _chain = fetch_tcb_info("00A06D080000") + + mock_get.assert_called_once() + assert tcb_info.tcb_info.id == "TDX" + + def test_cache_hit_skips_network(self, tmp_path): + """Test that fresh cache hit skips network fetch.""" + import base64 + import json as json_mod + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + # Write fresh cache in the new JSON format with issuer chain + cache_path = tmp_path / "tdx_tcb_info_00a06d080000.json" + cache_data = { + "body": base64.b64encode(SAMPLE_TCB_INFO_RESPONSE).decode("ascii"), + "issuer_chain_pem": "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + } + cache_path.write_text(json_mod.dumps(cache_data)) + + # Mock datetime so sample data appears fresh (nextUpdate is 2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + # Mock signature verification on cache hit + with patch('tinfoil.attestation.collateral_tdx.verify_tcb_info_signature'): + with patch('tinfoil.attestation.collateral_tdx.parse_pem_chain') as mock_pem: + mock_pem.return_value = [] + tcb_info, raw, _chain = fetch_tcb_info("00A06D080000") + + # Should not call network + mock_get.assert_not_called() + assert tcb_info.tcb_info.id == "TDX" + + def test_stale_cache_fetches_fresh(self, tmp_path): + """Test that stale cache triggers fresh fetch.""" + import base64 + import json as json_mod + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + # Write stale cache (expired next_update) in new JSON format + cache_path = tmp_path / "tdx_tcb_info_00a06d080000.json" + cache_data = { + "body": base64.b64encode(SAMPLE_STALE_TCB_INFO_RESPONSE).decode("ascii"), + "issuer_chain_pem": "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + } + cache_path.write_text(json_mod.dumps(cache_data)) + + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + with patch('tinfoil.attestation.collateral_tdx.verify_tcb_info_signature'): + mock_response = MagicMock() + mock_response.content = SAMPLE_TCB_INFO_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"TCB-Info-Issuer-Chain": "dummy"} + mock_get.return_value = mock_response + + with patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse: + mock_parse.return_value = [] + + tcb_info, raw, _chain = fetch_tcb_info("00A06D080000") + + # Should call network because cache is stale + mock_get.assert_called_once() + + +class TestFetchQeIdentityWithCache: + """Test fetch_qe_identity caching behavior.""" + + def test_cache_miss_fetches_from_network(self, tmp_path): + """Test that cache miss triggers network fetch.""" + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + with patch('tinfoil.attestation.collateral_tdx.verify_qe_identity_signature'): + mock_response = MagicMock() + mock_response.content = SAMPLE_QE_IDENTITY_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"SGX-Enclave-Identity-Issuer-Chain": "dummy"} + mock_get.return_value = mock_response + + with patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse: + mock_parse.return_value = [] + + qe_identity, raw, _chain = fetch_qe_identity() + + mock_get.assert_called_once() + assert qe_identity.enclave_identity.id == "TD_QE" + + def test_cache_hit_skips_network(self, tmp_path): + """Test that fresh cache hit skips network fetch.""" + import base64 + import json as json_mod + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + # Write fresh cache in the new JSON format with issuer chain + cache_path = tmp_path / "tdx_qe_identity.json" + cache_data = { + "body": base64.b64encode(SAMPLE_QE_IDENTITY_RESPONSE).decode("ascii"), + "issuer_chain_pem": "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + } + cache_path.write_text(json_mod.dumps(cache_data)) + + # Mock datetime so sample data appears fresh (nextUpdate is 2026-01-16) + mock_now = datetime(2026, 1, 10, tzinfo=timezone.utc) + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_dt.now.return_value = mock_now + mock_dt.fromisoformat = datetime.fromisoformat + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + # Mock signature verification on cache hit + with patch('tinfoil.attestation.collateral_tdx.verify_qe_identity_signature'): + with patch('tinfoil.attestation.collateral_tdx.parse_pem_chain') as mock_pem: + mock_pem.return_value = [] + qe_identity, raw, _chain = fetch_qe_identity() + + # Should not call network + mock_get.assert_not_called() + assert qe_identity.enclave_identity.id == "TD_QE" + + +# ============================================================================= +# CRL Tests +# ============================================================================= + +from tinfoil.attestation.collateral_tdx import ( + PckCrl, + _get_crl_cache_path, + _is_crl_fresh, + _determine_pck_ca_type, + fetch_pck_crl, + validate_certificate_revocation, +) +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec + + +class TestCrlCacheHelpers: + """Test CRL cache helper functions.""" + + def test_crl_cache_path_platform(self): + """Test CRL cache path generation for platform CA.""" + path = _get_crl_cache_path("platform") + assert "tdx_pck_crl_platform.json" in path + + def test_crl_cache_path_processor(self): + """Test CRL cache path generation for processor CA.""" + path = _get_crl_cache_path("processor") + assert "tdx_pck_crl_processor.json" in path + + def test_crl_cache_path_lowercase(self): + """Test CRL cache path is lowercase.""" + path1 = _get_crl_cache_path("Platform") + path2 = _get_crl_cache_path("platform") + assert path1 == path2 + + def test_is_crl_fresh_valid(self): + """Test CRL freshness check with valid (not expired) CRL.""" + # Create a mock CRL with next_update in the future + mock_crl = MagicMock() + mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=30) + + assert _is_crl_fresh(mock_crl) is True + + def test_is_crl_fresh_expired(self): + """Test CRL freshness check with expired CRL.""" + mock_crl = MagicMock() + mock_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) + + assert _is_crl_fresh(mock_crl) is False + + def test_is_crl_fresh_no_next_update(self): + """Test CRL freshness check with no next_update.""" + mock_crl = MagicMock() + mock_crl.next_update_utc = None + + assert _is_crl_fresh(mock_crl) is False + + +class TestDeterminePckCaType: + """Test PCK CA type determination from certificate issuer.""" + + def _create_mock_cert_with_issuer(self, cn: str) -> MagicMock: + """Create a mock certificate with given issuer CN.""" + mock_attr = MagicMock() + mock_attr.oid = NameOID.COMMON_NAME + mock_attr.value = cn + + mock_issuer = MagicMock() + mock_issuer.__iter__ = MagicMock(return_value=iter([mock_attr])) + + mock_cert = MagicMock() + mock_cert.issuer = mock_issuer + + return mock_cert + + def test_platform_ca(self): + """Test detecting Platform CA from issuer CN.""" + mock_cert = self._create_mock_cert_with_issuer("Intel SGX PCK Platform CA") + result = _determine_pck_ca_type(mock_cert) + assert result == "platform" + + def test_processor_ca(self): + """Test detecting Processor CA from issuer CN.""" + mock_cert = self._create_mock_cert_with_issuer("Intel SGX PCK Processor CA") + result = _determine_pck_ca_type(mock_cert) + assert result == "processor" + + def test_unknown_ca_raises_error(self): + """Test that unknown CA type raises error.""" + mock_cert = self._create_mock_cert_with_issuer("Unknown CA") + with pytest.raises(CollateralError, match="Could not determine PCK CA type"): + _determine_pck_ca_type(mock_cert) + + +class TestFetchPckCrl: + """Test fetch_pck_crl function.""" + + def test_invalid_ca_type_raises_error(self): + """Test that invalid CA type raises error.""" + with pytest.raises(CollateralError, match="Invalid CA type"): + fetch_pck_crl("invalid") + + def test_cache_hit_skips_network(self, tmp_path): + """Test that fresh cache hit skips network fetch.""" + import base64 + import json as json_mod + # Create a mock CRL + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives import hashes + from cryptography import x509 + from cryptography.x509 import CertificateRevocationListBuilder + import datetime as dt_module + + # Generate a key for signing + private_key = ec.generate_private_key(ec.SECP256R1()) + + # Build a minimal CRL + builder = CertificateRevocationListBuilder() + builder = builder.issuer_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "Test CA"), + ])) + builder = builder.last_update(dt_module.datetime.now(dt_module.timezone.utc)) + builder = builder.next_update(dt_module.datetime.now(dt_module.timezone.utc) + dt_module.timedelta(days=30)) + + crl = builder.sign(private_key, hashes.SHA256()) + # Use PEM encoding - fetch_pck_crl expects PEM format + crl_pem = crl.public_bytes(serialization.Encoding.PEM) + + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + # Write fresh CRL cache in new JSON format (body is PEM bytes) + cache_path = tmp_path / "tdx_pck_crl_platform.json" + cache_data = { + "body": base64.b64encode(crl_pem).decode("ascii"), + "issuer_chain_pem": "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + } + cache_path.write_text(json_mod.dumps(cache_data)) + + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + # Mock signature verification on cache hit + with patch('tinfoil.attestation.collateral_tdx._verify_crl_signature'): + with patch('tinfoil.attestation.collateral_tdx.parse_pem_chain') as mock_pem: + mock_pem.return_value = [] + result = fetch_pck_crl("platform") + + # Should not call network + mock_get.assert_not_called() + assert result.ca_type == "platform" + + def test_cache_miss_fetches_from_network(self, tmp_path): + """Test that cache miss triggers network fetch.""" + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives import hashes, serialization + from cryptography import x509 + from cryptography.x509 import CertificateRevocationListBuilder + import datetime as dt_module + + # Generate keys + private_key = ec.generate_private_key(ec.SECP256R1()) + + # Build a minimal CRL + builder = CertificateRevocationListBuilder() + builder = builder.issuer_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "Test CA"), + ])) + builder = builder.last_update(dt_module.datetime.now(dt_module.timezone.utc)) + builder = builder.next_update(dt_module.datetime.now(dt_module.timezone.utc) + dt_module.timedelta(days=30)) + + crl = builder.sign(private_key, hashes.SHA256()) + # Use PEM encoding - fetch_pck_crl expects PEM format from Intel PCS + crl_pem = crl.public_bytes(serialization.Encoding.PEM) + + with patch('tinfoil.attestation.collateral_tdx._TDX_CACHE_DIR', str(tmp_path)): + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + with patch('tinfoil.attestation.collateral_tdx._verify_crl_signature'): + with patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse: + mock_parse.return_value = [] + + mock_response = MagicMock() + mock_response.content = crl_pem + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"SGX-PCK-CRL-Issuer-Chain": "dummy"} + mock_get.return_value = mock_response + + result = fetch_pck_crl("platform") + + mock_get.assert_called_once() + assert result.ca_type == "platform" + + +class TestValidateCertificateRevocation: + """Test certificate revocation validation.""" + + def _create_mock_crl(self, revoked_serials: list[int] = None) -> MagicMock: + """Create a mock CRL with optional revoked serials.""" + mock_crl = MagicMock() + mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=30) + + def get_revoked(serial): + if revoked_serials and serial in revoked_serials: + revoked = MagicMock() + revoked.revocation_date_utc = datetime.now(timezone.utc) - timedelta(days=1) + return revoked + return None + + mock_crl.get_revoked_certificate_by_serial_number = get_revoked + return mock_crl + + def _create_mock_cert(self, serial: int) -> MagicMock: + """Create a mock certificate with given serial number.""" + mock_cert = MagicMock() + mock_cert.serial_number = serial + return mock_cert + + def test_certificate_not_revoked(self): + """Test that non-revoked certificate passes validation.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + ) + + mock_cert = self._create_mock_cert(serial=12345) + + # Should not raise + validate_certificate_revocation(collateral, mock_cert) + + def test_certificate_revoked(self): + """Test that revoked certificate fails validation.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_crl = self._create_mock_crl(revoked_serials=[12345]) + pck_crl = PckCrl( + crl=mock_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + ) + + mock_cert = self._create_mock_cert(serial=12345) + + with pytest.raises(CollateralError, match="has been revoked"): + validate_certificate_revocation(collateral, mock_cert) + + def test_no_crl_raises_error(self): + """Test that missing CRL raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=None, # No CRL + ) + + mock_cert = self._create_mock_cert(serial=12345) + + with pytest.raises(CollateralError, match="CRL not available"): + validate_certificate_revocation(collateral, mock_cert) + + def test_expired_crl_raises_error(self): + """Test that expired CRL raises error.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_crl = MagicMock() + mock_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) # Expired + + pck_crl = PckCrl( + crl=mock_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) - timedelta(days=1), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + ) + + mock_cert = self._create_mock_cert(serial=12345) + + with pytest.raises(CollateralError, match="CRL has expired"): + validate_certificate_revocation(collateral, mock_cert) + + def test_intermediate_cert_not_revoked(self): + """Test that non-revoked intermediate CA certificate passes validation.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_pck_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_pck_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + mock_root_crl = self._create_mock_crl(revoked_serials=[]) + from tinfoil.attestation.collateral_tdx import RootCrl + root_crl = RootCrl( + crl=mock_root_crl, + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + root_crl=root_crl, + ) + + mock_pck_cert = self._create_mock_cert(serial=12345) + mock_intermediate_cert = self._create_mock_cert(serial=67890) + + # Should not raise + validate_certificate_revocation(collateral, mock_pck_cert, mock_intermediate_cert) + + def test_intermediate_cert_revoked(self): + """Test that revoked intermediate CA certificate fails validation.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_pck_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_pck_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + # Root CRL has the intermediate CA's serial as revoked + mock_root_crl = self._create_mock_crl(revoked_serials=[67890]) + from tinfoil.attestation.collateral_tdx import RootCrl + root_crl = RootCrl( + crl=mock_root_crl, + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + root_crl=root_crl, + ) + + mock_pck_cert = self._create_mock_cert(serial=12345) + mock_intermediate_cert = self._create_mock_cert(serial=67890) + + with pytest.raises(CollateralError, match="Intermediate CA certificate has been revoked"): + validate_certificate_revocation(collateral, mock_pck_cert, mock_intermediate_cert) + + def test_no_root_crl_with_intermediate_raises_error(self): + """Test that missing root CRL raises error when checking intermediate.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_pck_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_pck_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + root_crl=None, # No root CRL + ) + + mock_pck_cert = self._create_mock_cert(serial=12345) + mock_intermediate_cert = self._create_mock_cert(serial=67890) + + with pytest.raises(CollateralError, match="Root CRL not available"): + validate_certificate_revocation(collateral, mock_pck_cert, mock_intermediate_cert) + + def test_expired_root_crl_raises_error(self): + """Test that expired root CRL raises error when checking intermediate.""" + tcb_info = parse_tcb_info_response(SAMPLE_TCB_INFO_JSON.encode()) + qe_identity = parse_qe_identity_response(SAMPLE_QE_IDENTITY_JSON.encode()) + + mock_pck_crl = self._create_mock_crl(revoked_serials=[]) + pck_crl = PckCrl( + crl=mock_pck_crl, + ca_type="platform", + next_update=datetime.now(timezone.utc) + timedelta(days=30), + ) + + mock_root_crl = MagicMock() + mock_root_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) # Expired + from tinfoil.attestation.collateral_tdx import RootCrl + root_crl = RootCrl( + crl=mock_root_crl, + next_update=datetime.now(timezone.utc) - timedelta(days=1), + ) + + collateral = TdxCollateral( + tcb_info=tcb_info, + qe_identity=qe_identity, + tcb_info_raw=SAMPLE_TCB_INFO_JSON.encode(), + qe_identity_raw=SAMPLE_QE_IDENTITY_JSON.encode(), + pck_crl=pck_crl, + root_crl=root_crl, + ) + + mock_pck_cert = self._create_mock_cert(serial=12345) + mock_intermediate_cert = self._create_mock_cert(serial=67890) + + with pytest.raises(CollateralError, match="Root CA CRL has expired"): + validate_certificate_revocation(collateral, mock_pck_cert, mock_intermediate_cert) + + +# ============================================================================= +# TCB Evaluation Data Numbers Tests +# ============================================================================= + +from tinfoil.attestation.collateral_tdx import ( + TcbEvalNumber, + TcbEvaluationDataNumbers, + _parse_tcb_eval_numbers_response, + fetch_tcb_evaluation_data_numbers, + calculate_min_tcb_evaluation_data_number, +) + + +SAMPLE_TCB_EVAL_NUMBERS_RESPONSE = b'''{ + "tcbEvaluationDataNumbers": { + "id": "TDX", + "version": 1, + "issueDate": "2026-01-15T19:17:02Z", + "nextUpdate": "2026-02-14T19:17:02Z", + "tcbEvalNumbers": [ + {"tcbEvaluationDataNumber": 20, "tcbRecoveryEventDate": "2025-08-12T00:00:00Z", "tcbDate": "2025-08-13T00:00:00Z"}, + {"tcbEvaluationDataNumber": 19, "tcbRecoveryEventDate": "2025-05-13T00:00:00Z", "tcbDate": "2025-05-14T00:00:00Z"}, + {"tcbEvaluationDataNumber": 18, "tcbRecoveryEventDate": "2024-11-12T00:00:00Z", "tcbDate": "2024-11-13T00:00:00Z"}, + {"tcbEvaluationDataNumber": 17, "tcbRecoveryEventDate": "2024-03-12T00:00:00Z", "tcbDate": "2024-03-13T00:00:00Z"}, + {"tcbEvaluationDataNumber": 16, "tcbRecoveryEventDate": "2023-08-08T00:00:00Z", "tcbDate": "2023-08-09T00:00:00Z"} + ] + }, + "signature": "dummy" +}''' + + +class TestParseTcbEvalNumbersResponse: + """Test parsing of tcbevaluationdatanumbers response.""" + + def test_parse_valid_response(self): + """Test parsing a valid response.""" + result = _parse_tcb_eval_numbers_response(SAMPLE_TCB_EVAL_NUMBERS_RESPONSE) + + assert result.id == "TDX" + assert result.version == 1 + assert len(result.tcb_eval_numbers) == 5 + assert result.tcb_eval_numbers[0].tcb_evaluation_data_number == 20 + assert result.tcb_eval_numbers[4].tcb_evaluation_data_number == 16 + + def test_parse_dates(self): + """Test that dates are parsed correctly.""" + result = _parse_tcb_eval_numbers_response(SAMPLE_TCB_EVAL_NUMBERS_RESPONSE) + + # Check first entry + assert result.tcb_eval_numbers[0].tcb_recovery_event_date.year == 2025 + assert result.tcb_eval_numbers[0].tcb_recovery_event_date.month == 8 + assert result.tcb_eval_numbers[0].tcb_recovery_event_date.day == 12 + + def test_parse_invalid_json(self): + """Test that invalid JSON raises error.""" + with pytest.raises(CollateralError, match="Failed to parse"): + _parse_tcb_eval_numbers_response(b"not json") + + +class TestFetchTcbEvaluationDataNumbers: + """Test fetching TCB evaluation data numbers.""" + + def test_fetch_success(self): + """Test successful fetch with signature verification mocked out.""" + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get, \ + patch('tinfoil.attestation.collateral_tdx._parse_issuer_chain_header') as mock_parse, \ + patch('tinfoil.attestation.collateral_tdx.verify_tcb_eval_numbers_signature'): + mock_response = MagicMock() + mock_response.content = SAMPLE_TCB_EVAL_NUMBERS_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {"TCB-Evaluation-Data-Numbers-Issuer-Chain": "dummy-chain"} + mock_get.return_value = mock_response + mock_parse.return_value = [] + + result = fetch_tcb_evaluation_data_numbers() + + mock_get.assert_called_once() + assert result.id == "TDX" + + def test_fetch_missing_issuer_chain_header(self): + """Test that missing issuer chain header raises error.""" + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + mock_response = MagicMock() + mock_response.content = SAMPLE_TCB_EVAL_NUMBERS_RESPONSE + mock_response.raise_for_status = MagicMock() + mock_response.headers = {} + mock_get.return_value = mock_response + + with pytest.raises(CollateralError, match="missing TCB-Evaluation-Data-Numbers-Issuer-Chain"): + fetch_tcb_evaluation_data_numbers() + + def test_fetch_failure(self): + """Test fetch failure raises error.""" + import requests as req + with patch('tinfoil.attestation.collateral_tdx.requests.get') as mock_get: + mock_get.side_effect = req.RequestException("Network error") + + with pytest.raises(CollateralError, match="HTTP GET"): + fetch_tcb_evaluation_data_numbers() + + +class TestCalculateMinTcbEvaluationDataNumber: + """Test calculate_min_tcb_evaluation_data_number function.""" + + def test_calculate_with_recent_cutoff(self): + """Test calculation finds correct minimum with recent cutoff.""" + with patch('tinfoil.attestation.collateral_tdx.fetch_tcb_evaluation_data_numbers') as mock_fetch: + mock_fetch.return_value = _parse_tcb_eval_numbers_response(SAMPLE_TCB_EVAL_NUMBERS_RESPONSE) + + # With 365-day cutoff from 2026-01-15, numbers from 2025-01-15 or later are acceptable + # Number 19 (2025-05-13) should be the minimum acceptable + # Number 18 (2024-11-12) is older than 1 year + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_now = datetime(2026, 1, 15, tzinfo=timezone.utc) + mock_dt.now.return_value = mock_now + mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw) + + result = calculate_min_tcb_evaluation_data_number(max_age_days=365) + + # Should return 19 since 18 is older than 1 year + assert result == 19 + + def test_calculate_with_shorter_cutoff(self): + """Test calculation with shorter max age.""" + with patch('tinfoil.attestation.collateral_tdx.fetch_tcb_evaluation_data_numbers') as mock_fetch: + mock_fetch.return_value = _parse_tcb_eval_numbers_response(SAMPLE_TCB_EVAL_NUMBERS_RESPONSE) + + # With 180-day cutoff from 2026-01-15, cutoff is ~2025-07-19 + # Number 20 (2025-08-12) should be acceptable + # Number 19 (2025-05-13) is older than 180 days + with patch('tinfoil.attestation.collateral_tdx.datetime') as mock_dt: + mock_now = datetime(2026, 1, 15, tzinfo=timezone.utc) + mock_dt.now.return_value = mock_now + mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw) + + result = calculate_min_tcb_evaluation_data_number(max_age_days=180) + + assert result == 20 + + def test_calculate_empty_numbers_raises(self): + """Test that empty numbers list raises error.""" + with patch('tinfoil.attestation.collateral_tdx.fetch_tcb_evaluation_data_numbers') as mock_fetch: + mock_fetch.return_value = TcbEvaluationDataNumbers( + id="TDX", + version=1, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + tcb_eval_numbers=[], + signature="dummy", + ) + + with pytest.raises(CollateralError, match="No TCB evaluation data numbers found"): + calculate_min_tcb_evaluation_data_number() + + def test_calculate_all_too_old_raises(self): + """Test that all numbers too old raises error.""" + # Create response with only old numbers + old_numbers_response = b'''{ + "tcbEvaluationDataNumbers": { + "id": "TDX", + "version": 1, + "issueDate": "2026-01-15T19:17:02Z", + "nextUpdate": "2026-02-14T19:17:02Z", + "tcbEvalNumbers": [ + {"tcbEvaluationDataNumber": 10, "tcbRecoveryEventDate": "2020-01-01T00:00:00Z", "tcbDate": "2020-01-02T00:00:00Z"} + ] + }, + "signature": "dummy" + }''' + + with patch('tinfoil.attestation.collateral_tdx.fetch_tcb_evaluation_data_numbers') as mock_fetch: + mock_fetch.return_value = _parse_tcb_eval_numbers_response(old_numbers_response) + + with pytest.raises(CollateralError, match="older than 365 days"): + calculate_min_tcb_evaluation_data_number() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_pck_extensions.py b/tests/test_pck_extensions.py new file mode 100644 index 0000000..2f05627 --- /dev/null +++ b/tests/test_pck_extensions.py @@ -0,0 +1,144 @@ +""" +Unit tests for PCK certificate extension parsing. +""" + +import pytest + +from tinfoil.attestation.pck_extensions import ( + PckExtensions, + PckCertTCB, + PckExtensionError, + OID_SGX_EXTENSION, + OID_FMSPC, + OID_PCEID, + OID_PPID, + OID_TCB, + FMSPC_SIZE, + PCEID_SIZE, + PPID_SIZE, + TCB_COMPONENTS_COUNT, +) + + +# ============================================================================= +# Test OID Constants +# ============================================================================= + +class TestOidConstants: + """Test Intel SGX OID constants.""" + + def test_sgx_extension_oid(self): + """Test SGX extension OID value.""" + assert OID_SGX_EXTENSION.dotted_string == "1.2.840.113741.1.13.1" + + def test_fmspc_oid(self): + """Test FMSPC OID value.""" + assert OID_FMSPC.dotted_string == "1.2.840.113741.1.13.1.4" + + def test_pceid_oid(self): + """Test PCEID OID value.""" + assert OID_PCEID.dotted_string == "1.2.840.113741.1.13.1.3" + + def test_ppid_oid(self): + """Test PPID OID value.""" + assert OID_PPID.dotted_string == "1.2.840.113741.1.13.1.1" + + def test_tcb_oid(self): + """Test TCB OID value.""" + assert OID_TCB.dotted_string == "1.2.840.113741.1.13.1.2" + + +# ============================================================================= +# Test Size Constants +# ============================================================================= + +class TestSizeConstants: + """Test size constants.""" + + def test_fmspc_size(self): + """Test FMSPC is 6 bytes.""" + assert FMSPC_SIZE == 6 + + def test_pceid_size(self): + """Test PCEID is 2 bytes.""" + assert PCEID_SIZE == 2 + + def test_ppid_size(self): + """Test PPID is 16 bytes.""" + assert PPID_SIZE == 16 + + def test_tcb_components_count(self): + """Test TCB has 16 components.""" + assert TCB_COMPONENTS_COUNT == 16 + + +# ============================================================================= +# Test Data Classes +# ============================================================================= + +class TestPckCertTCB: + """Test PckCertTCB dataclass.""" + + def test_create_tcb(self): + """Test creating PckCertTCB.""" + tcb = PckCertTCB( + pce_svn=13, + cpu_svn=bytes(16), + tcb_components=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ) + assert tcb.pce_svn == 13 + assert len(tcb.cpu_svn) == 16 + assert len(tcb.tcb_components) == 16 + + def test_tcb_str(self): + """Test string representation.""" + tcb = PckCertTCB( + pce_svn=13, + cpu_svn=bytes([0xAB] * 16), + tcb_components=[0] * 16, + ) + s = str(tcb) + assert "pce_svn=13" in s + assert "abab" in s.lower() + + +class TestPckExtensions: + """Test PckExtensions dataclass.""" + + def test_create_extensions(self): + """Test creating PckExtensions.""" + tcb = PckCertTCB( + pce_svn=13, + cpu_svn=bytes(16), + tcb_components=[0] * 16, + ) + ext = PckExtensions( + ppid="00112233445566778899aabbccddeeff", + tcb=tcb, + pceid="0000", + fmspc="00606a000000", + ) + assert ext.fmspc == "00606a000000" + assert ext.pceid == "0000" + assert len(ext.ppid) == 32 # 16 bytes hex + + def test_extensions_str(self): + """Test string representation.""" + tcb = PckCertTCB( + pce_svn=13, + cpu_svn=bytes(16), + tcb_components=[0] * 16, + ) + ext = PckExtensions( + ppid="00112233445566778899aabbccddeeff", + tcb=tcb, + pceid="0000", + fmspc="00606a000000", + ) + s = str(ext) + assert "fmspc=00606a000000" in s + assert "pceid=0000" in s + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_sev_validation.py b/tests/test_sev_validation.py index 9650f56..6ae1059 100644 --- a/tests/test_sev_validation.py +++ b/tests/test_sev_validation.py @@ -1,12 +1,12 @@ import pytest -from tinfoil.attestation.validate import ( +from tinfoil.attestation.validate_sev import ( ValidationOptions, validate_report, _validate_policy, _compare_policy_versions, _validate_platform_info ) -from tinfoil.attestation.abi_sevsnp import ( +from tinfoil.attestation.abi_sev import ( Report, SnpPolicy, SnpPlatformInfo, @@ -14,7 +14,7 @@ SignerInfo, ReportSigner ) -from tinfoil.attestation.verify import CertificateChain +from tinfoil.attestation.verify_sev import CertificateChain class TestValidationOptions: diff --git a/tests/test_tdx_abi.py b/tests/test_tdx_abi.py new file mode 100644 index 0000000..002a229 --- /dev/null +++ b/tests/test_tdx_abi.py @@ -0,0 +1,548 @@ +""" +Unit tests for TDX quote parsing (abi_tdx.py). +""" + +import pytest +import struct + +from tinfoil.attestation.abi_tdx import ( + # Constants + QUOTE_MIN_SIZE, + QUOTE_VERSION_V4, + QUOTE_VERSION_V5, + TEE_TDX, + ATTESTATION_KEY_TYPE_ECDSA_P256, + INTEL_QE_VENDOR_ID, + HEADER_SIZE, + TD_QUOTE_BODY_SIZE, + CERT_DATA_TYPE_PCK_CERT_CHAIN, + CERT_DATA_TYPE_QE_REPORT, + QE_REPORT_SIZE, + # Dataclasses + TdxHeader, + TdQuoteBody, + QeReport, + SignedData, + QuoteV4, + CertificationData, + PckCertChainData, + QeReportCertificationData, + # Parsing functions + parse_quote, + TdxQuoteParseError, + _parse_header, + _parse_td_quote_body, + _parse_qe_report, +) + + +# ============================================================================= +# Test Fixtures - Synthetic Quote Generation +# ============================================================================= + +def build_header( + version: int = QUOTE_VERSION_V4, + attestation_key_type: int = ATTESTATION_KEY_TYPE_ECDSA_P256, + tee_type: int = TEE_TDX, + reserved: bytes = b'\x00\x00\x00\x00', + qe_vendor_id: bytes = INTEL_QE_VENDOR_ID, + user_data: bytes = b'\x00' * 20, +) -> bytes: + """Build a synthetic TDX quote header (48 bytes). + + Note: Bytes 8-11 are reserved. Some older specs labeled these as + QE_SVN/PCE_SVN, but they are always zero in actual quotes. The real + SVN values come from PCK certificate extensions and QE Report. + """ + header = b'' + header += struct.pack(' bytes: + """Build a synthetic TD quote body (584 bytes).""" + body = b'' + body += tee_tcb_svn[:16].ljust(16, b'\x00') + body += mr_seam[:48].ljust(48, b'\x00') + body += mr_signer_seam[:48].ljust(48, b'\x00') + body += seam_attributes[:8].ljust(8, b'\x00') + body += td_attributes[:8].ljust(8, b'\x00') + body += xfam[:8].ljust(8, b'\x00') + body += mr_td[:48].ljust(48, b'\x00') + body += mr_config_id[:48].ljust(48, b'\x00') + body += mr_owner[:48].ljust(48, b'\x00') + body += mr_owner_config[:48].ljust(48, b'\x00') + body += rtmr0[:48].ljust(48, b'\x00') + body += rtmr1[:48].ljust(48, b'\x00') + body += rtmr2[:48].ljust(48, b'\x00') + body += rtmr3[:48].ljust(48, b'\x00') + body += report_data[:64].ljust(64, b'\x00') + assert len(body) == TD_QUOTE_BODY_SIZE + return body + + +def build_qe_report( + cpu_svn: bytes = b'\x00' * 16, + misc_select: int = 0, + attributes: bytes = b'\x00' * 16, + mr_enclave: bytes = b'\xee' * 32, + mr_signer: bytes = b'\xff' * 32, + isv_prod_id: int = 1, + isv_svn: int = 2, + report_data: bytes = b'\x00' * 64, +) -> bytes: + """Build a synthetic QE report (384 bytes).""" + report = b'' + report += cpu_svn[:16].ljust(16, b'\x00') # 0x00-0x10 + report += struct.pack(' bytes: + """Build PCK cert chain data (type 5).""" + return struct.pack(' bytes: + """Build signed data with type 6 (QE report certification data).""" + if qe_report is None: + qe_report = build_qe_report() + + # QE report certification data content + qe_cert_content = qe_report + qe_report_signature + struct.pack(' bytes: + """Build a complete synthetic TDX quote.""" + if header is None: + header = build_header() + if body is None: + body = build_td_quote_body() + if signed_data is None: + signed_data = build_signed_data() + + return header + body + struct.pack(' bytes: + """Create a mock in-toto payload with multiplatform predicate.""" + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + "snp_measurement": snp_measurement, + "tdx_measurement": { + "rtmr1": rtmr1, + "rtmr2": rtmr2, + }, + }, + "subject": [{"digest": {"sha256": digest}}], + } + return json.dumps(payload).encode() + + def test_parse_multiplatform_extracts_all_registers(self): + """Test that multiplatform parsing extracts snp_measurement, rtmr1, rtmr2.""" + from tinfoil.sigstore import verify_attestation + + payload = self._create_mock_bundle_payload( + snp_measurement=SAMPLE_SNP_MEASUREMENT, + rtmr1=SAMPLE_RTMR1, + rtmr2=SAMPLE_RTMR2, + digest="test_digest_abc123", + ) + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload, + ) + + with patch('tinfoil.sigstore.Bundle'): + result = verify_attestation( + bundle_json=b'{}', + digest="test_digest_abc123", + repo="test/repo", + ) + + assert result.type == PredicateType.SNP_TDX_MULTIPLATFORM_v1 + assert len(result.registers) == 3 + assert result.registers[0] == SAMPLE_SNP_MEASUREMENT + assert result.registers[1] == SAMPLE_RTMR1 + assert result.registers[2] == SAMPLE_RTMR2 + + def test_parse_multiplatform_missing_snp_measurement_fails(self): + """Test that missing snp_measurement raises ValueError.""" + from tinfoil.sigstore import verify_attestation + + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + # snp_measurement missing + "tdx_measurement": { + "rtmr1": SAMPLE_RTMR1, + "rtmr2": SAMPLE_RTMR2, + }, + }, + "subject": [{"digest": {"sha256": "test_digest"}}], + } + payload_bytes = json.dumps(payload).encode() + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload_bytes, + ) + + with patch('tinfoil.sigstore.Bundle'): + with pytest.raises(ValueError, match="no snp_measurement"): + verify_attestation(b'{}', "test_digest", "test/repo") + + def test_parse_multiplatform_missing_tdx_measurement_fails(self): + """Test that missing tdx_measurement struct raises ValueError.""" + from tinfoil.sigstore import verify_attestation + + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + "snp_measurement": SAMPLE_SNP_MEASUREMENT, + # tdx_measurement missing + }, + "subject": [{"digest": {"sha256": "test_digest"}}], + } + payload_bytes = json.dumps(payload).encode() + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload_bytes, + ) + + with patch('tinfoil.sigstore.Bundle'): + with pytest.raises(ValueError, match="no tdx_measurement"): + verify_attestation(b'{}', "test_digest", "test/repo") + + def test_parse_multiplatform_missing_rtmr1_fails(self): + """Test that missing rtmr1 in tdx_measurement raises ValueError.""" + from tinfoil.sigstore import verify_attestation + + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + "snp_measurement": SAMPLE_SNP_MEASUREMENT, + "tdx_measurement": { + # rtmr1 missing + "rtmr2": SAMPLE_RTMR2, + }, + }, + "subject": [{"digest": {"sha256": "test_digest"}}], + } + payload_bytes = json.dumps(payload).encode() + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload_bytes, + ) + + with patch('tinfoil.sigstore.Bundle'): + with pytest.raises(ValueError, match="missing rtmr1 or rtmr2"): + verify_attestation(b'{}', "test_digest", "test/repo") + + def test_parse_multiplatform_missing_rtmr2_fails(self): + """Test that missing rtmr2 in tdx_measurement raises ValueError.""" + from tinfoil.sigstore import verify_attestation + + payload = { + "predicateType": PredicateType.SNP_TDX_MULTIPLATFORM_v1.value, + "predicate": { + "snp_measurement": SAMPLE_SNP_MEASUREMENT, + "tdx_measurement": { + "rtmr1": SAMPLE_RTMR1, + # rtmr2 missing + }, + }, + "subject": [{"digest": {"sha256": "test_digest"}}], + } + payload_bytes = json.dumps(payload).encode() + + with patch('tinfoil.sigstore.Verifier') as mock_verifier_cls: + mock_verifier = MagicMock() + mock_verifier_cls.production.return_value = mock_verifier + mock_verifier.verify_dsse.return_value = ( + 'application/vnd.in-toto+json', + payload_bytes, + ) + + with patch('tinfoil.sigstore.Bundle'): + with pytest.raises(ValueError, match="missing rtmr1 or rtmr2"): + verify_attestation(b'{}', "test_digest", "test/repo") + + +# ============================================================================= +# RTMR3-Zero Enforcement Tests +# ============================================================================= + +class TestRtmr3ZeroEnforcement: + """Test RTMR3-zero enforcement in measurement comparison.""" + + def test_rtmr3_zero_constant_is_correct(self): + """Test that RTMR3_ZERO constant is 96 hex zeros (48 bytes).""" + assert len(RTMR3_ZERO) == 96 + assert RTMR3_ZERO == "0" * 96 + + def test_multiplatform_vs_tdx_with_zero_rtmr3_passes(self): + """Test comparison passes when TDX RTMR3 is zeros.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, # MRTD (index 0) + SAMPLE_RTMR0, # RTMR0 (index 1) + SAMPLE_RTMR1, # RTMR1 (index 2) - must match + SAMPLE_RTMR2, # RTMR2 (index 3) - must match + SAMPLE_RTMR3_ZEROS, # RTMR3 (index 4) - must be zeros + ], + ) + + # Should not raise + multiplatform.assert_equal(tdx) + + def test_multiplatform_vs_tdx_with_nonzero_rtmr3_fails(self): + """Test comparison fails when TDX RTMR3 is not zeros.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + SAMPLE_RTMR1, + SAMPLE_RTMR2, + SAMPLE_RTMR3_NONZERO, # RTMR3 is NOT zeros + ], + ) + + with pytest.raises(Rtmr3NotZeroError, match="RTMR3 must be zeros"): + multiplatform.assert_equal(tdx) + + def test_multiplatform_vs_tdx_rtmr1_mismatch_fails(self): + """Test comparison fails when RTMR1 doesn't match.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + "wrong_rtmr1" + "0" * 80, # Wrong RTMR1 + SAMPLE_RTMR2, + SAMPLE_RTMR3_ZEROS, + ], + ) + + with pytest.raises(MeasurementMismatchError): + multiplatform.assert_equal(tdx) + + def test_multiplatform_vs_tdx_rtmr2_mismatch_fails(self): + """Test comparison fails when RTMR2 doesn't match.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + SAMPLE_RTMR1, + "wrong_rtmr2" + "0" * 80, # Wrong RTMR2 + SAMPLE_RTMR3_ZEROS, + ], + ) + + with pytest.raises(MeasurementMismatchError): + multiplatform.assert_equal(tdx) + + def test_reverse_comparison_tdx_vs_multiplatform(self): + """Test reverse comparison (TDX.assert_equal(multiplatform)) also works.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + SAMPLE_RTMR1, + SAMPLE_RTMR2, + SAMPLE_RTMR3_ZEROS, + ], + ) + + # Reverse comparison should also work + tdx.assert_equal(multiplatform) + + def test_reverse_comparison_with_nonzero_rtmr3_fails(self): + """Test reverse comparison also enforces RTMR3-zero.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=[ + SAMPLE_MRTD, + SAMPLE_RTMR0, + SAMPLE_RTMR1, + SAMPLE_RTMR2, + SAMPLE_RTMR3_NONZERO, + ], + ) + + with pytest.raises(Rtmr3NotZeroError): + tdx.assert_equal(multiplatform) + + def test_multiplatform_vs_snp_compares_snp_measurement(self): + """Test multiplatform vs SNP compares the SNP measurement register.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + snp = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=[SAMPLE_SNP_MEASUREMENT], + ) + + # Should not raise - SNP measurement matches + multiplatform.assert_equal(snp) + + def test_multiplatform_vs_snp_mismatch_fails(self): + """Test multiplatform vs SNP fails when SNP measurement doesn't match.""" + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=[SAMPLE_SNP_MEASUREMENT, SAMPLE_RTMR1, SAMPLE_RTMR2], + ) + + snp = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["wrong_snp_" + "0" * 86], # Wrong SNP measurement + ) + + with pytest.raises(MeasurementMismatchError): + multiplatform.assert_equal(snp) + + +# ============================================================================= +# Module Identity Matching with Real TEE_TCB_SVN Values +# ============================================================================= + +class TestModuleIdentityMatchingWithRealTeeTcbSvn: + """Test module-identity matching with realistic tee_tcb_svn byte patterns.""" + + def _create_tcb_info_with_modules( + self, + module_ids: list[str], + module_mrsigner: bytes = b"\xaa" * 48, + ) -> TcbInfo: + """Create TCB Info with specified module identities.""" + module_identities = [] + for module_id in module_ids: + module_identities.append( + TdxModuleIdentity( + id=module_id, + mrsigner=module_mrsigner, + attributes=b"\x00" * 8, + attributes_mask=b"\xff" * 8, + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=3, # Minimum minor version + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ) + ) + + return TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=module_identities, + tcb_levels=[], + ) + + # --- TEE_TCB_SVN byte layout tests --- + # TEE_TCB_SVN is 16 bytes where: + # - Byte 0 (index 0) = minor SVN (used for module TCB level matching) + # - Byte 1 (index 1) = major SVN (used to derive module ID: TDX_{major:02d}) + # - Bytes 2-15 = other TCB components + + def test_tee_tcb_svn_major_version_3_matches_tdx_03(self): + """Test TEE_TCB_SVN with major=3 matches TDX_03 module.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_01", "TDX_03", "TDX_05"]) + + # TEE_TCB_SVN: minor=5, major=3 (TDX module version 3) + tee_tcb_svn = bytes([5, 3] + [0] * 14) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_03" + + def test_tee_tcb_svn_major_version_1_matches_tdx_01(self): + """Test TEE_TCB_SVN with major=1 matches TDX_01 module.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_01", "TDX_03"]) + + # TEE_TCB_SVN: minor=2, major=1 (TDX module version 1) + tee_tcb_svn = bytes([2, 1] + [0] * 14) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_01" + + def test_tee_tcb_svn_major_version_5_matches_tdx_05(self): + """Test TEE_TCB_SVN with major=5 matches TDX_05 module.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_03", "TDX_05"]) + + # TEE_TCB_SVN: minor=0, major=5 (TDX module version 5) + tee_tcb_svn = bytes([0, 5] + [0] * 14) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_05" + + def test_tee_tcb_svn_unknown_major_version_raises(self): + """Test TEE_TCB_SVN with unknown major version raises error.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_01", "TDX_03"]) + + # TEE_TCB_SVN: minor=0, major=99 (no TDX_99 module exists) + tee_tcb_svn = bytes([0, 99] + [0] * 14) + + with pytest.raises(CollateralError, match="Unknown TDX module version TDX_99"): + get_tdx_module_identity(tcb_info, tee_tcb_svn) + + def test_tee_tcb_svn_major_version_0_matches_tdx_00(self): + """Test TEE_TCB_SVN with major=0 matches TDX_00 if present.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_00", "TDX_01"]) + + # TEE_TCB_SVN: minor=1, major=0 (TDX module version 0) + tee_tcb_svn = bytes([1, 0] + [0] * 14) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_00" + + def test_tee_tcb_svn_too_short_raises(self): + """Test TEE_TCB_SVN with < 2 bytes raises error.""" + tcb_info = self._create_tcb_info_with_modules(["TDX_03"]) + + # Only 1 byte - not enough to extract major version + tee_tcb_svn = bytes([5]) + + with pytest.raises(CollateralError, match="TEE_TCB_SVN is too short"): + get_tdx_module_identity(tcb_info, tee_tcb_svn) + + def test_validate_module_identity_with_real_svn_values(self): + """Test full module identity validation with realistic SVN values.""" + module_mrsigner = b"\xaa" * 48 + tcb_info = self._create_tcb_info_with_modules( + ["TDX_03"], + module_mrsigner=module_mrsigner, + ) + + # Realistic TEE_TCB_SVN: minor=5, major=3 + # This should match TDX_03 and pass TCB level check (minor >= 3) + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = module_mrsigner + seam_attributes = b"\x00" * 8 + + # Should not raise - validation succeeds + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_minor_svn_too_low_raises_error(self): + """Test that minor SVN below threshold raises CollateralError.""" + module_mrsigner = b"\xaa" * 48 + + tcb_info = TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=[ + TdxModuleIdentity( + id="TDX_03", + mrsigner=module_mrsigner, + attributes=b"\x00" * 8, + attributes_mask=b"\xff" * 8, + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=5, # Requires minor >= 5 + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ) + ], + tcb_levels=[], + ) + + # TEE_TCB_SVN: minor=2, major=3 - minor too low (< 5) + tee_tcb_svn = bytes([2, 3] + [0] * 14) + mr_signer_seam = module_mrsigner + seam_attributes = b"\x00" * 8 + + # No matching level found raises CollateralError + with pytest.raises(CollateralError, match="Could not find a TDX Module Identity TCB Level"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_mrsigner_mismatch_fails(self): + """Test that MR_SIGNER_SEAM mismatch raises error.""" + expected_mrsigner = b"\xaa" * 48 + tcb_info = self._create_tcb_info_with_modules( + ["TDX_03"], + module_mrsigner=expected_mrsigner, + ) + + # TEE_TCB_SVN: minor=5, major=3 + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = b"\xbb" * 48 # Wrong! + seam_attributes = b"\x00" * 8 + + with pytest.raises(CollateralError, match="MR_SIGNER_SEAM does not match"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_attributes_mismatch_fails(self): + """Test that SEAM_ATTRIBUTES mismatch under mask raises error.""" + module_mrsigner = b"\xaa" * 48 + + tcb_info = TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=[ + TdxModuleIdentity( + id="TDX_03", + mrsigner=module_mrsigner, + attributes=b"\x01" * 8, # Expected: all 0x01 + attributes_mask=b"\xff" * 8, # Full mask + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=3, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.UP_TO_DATE, + advisory_ids=[], + ), + ], + ) + ], + tcb_levels=[], + ) + + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = module_mrsigner + seam_attributes = b"\x00" * 8 # Wrong! Expected 0x01 + + with pytest.raises(CollateralError, match="SEAM_ATTRIBUTES do not match"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_validate_module_identity_revoked_status_fails(self): + """Test that REVOKED module TCB status raises error.""" + module_mrsigner = b"\xaa" * 48 + + tcb_info = TcbInfo( + id="TDX", + version=3, + issue_date=datetime.now(timezone.utc), + next_update=datetime.now(timezone.utc) + timedelta(days=30), + fmspc="00a06f000000", + pce_id="0000", + tcb_type=0, + tcb_evaluation_data_number=17, + tdx_module=None, + tdx_module_identities=[ + TdxModuleIdentity( + id="TDX_03", + mrsigner=module_mrsigner, + attributes=b"\x00" * 8, + attributes_mask=b"\xff" * 8, + tcb_levels=[ + TcbLevel( + tcb=Tcb( + sgx_tcb_components=[TcbComponent(svn=0)] * 16, + pce_svn=0, + tdx_tcb_components=[TcbComponent(svn=0)] * 16, + isv_svn=3, + ), + tcb_date="2024-01-01T00:00:00Z", + tcb_status=TcbStatus.REVOKED, # REVOKED! + advisory_ids=[], + ), + ], + ) + ], + tcb_levels=[], + ) + + tee_tcb_svn = bytes([5, 3] + [0] * 14) + mr_signer_seam = module_mrsigner + seam_attributes = b"\x00" * 8 + + with pytest.raises(CollateralError, match="REVOKED"): + validate_tdx_module_identity( + tcb_info, tee_tcb_svn, mr_signer_seam, seam_attributes + ) + + def test_real_world_tee_tcb_svn_pattern(self): + """Test with a realistic TEE_TCB_SVN pattern from production.""" + # Real-world example: TDX module 3.x with minor version 5 + # and non-zero values in other positions + tcb_info = self._create_tcb_info_with_modules(["TDX_03"]) + + # Realistic 16-byte TEE_TCB_SVN: + # [0]=5 (minor), [1]=3 (major), [2]=0, [3]=0, ... + # Some positions might have non-zero values for platform TCB + tee_tcb_svn = bytes([5, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + result = get_tdx_module_identity(tcb_info, tee_tcb_svn) + assert result is not None + assert result.id == "TDX_03" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_tdx_verify.py b/tests/test_tdx_verify.py new file mode 100644 index 0000000..fb8c9f1 --- /dev/null +++ b/tests/test_tdx_verify.py @@ -0,0 +1,465 @@ +""" +Unit tests for TDX quote verification (verify_tdx.py). +""" + +import hashlib +import pytest +import struct + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, Prehashed + +from tinfoil.attestation.abi_tdx import ( + parse_quote, + QUOTE_HEADER_START, + QUOTE_BODY_END, + QE_REPORT_SIZE, + HEADER_SIZE, + TD_QUOTE_BODY_SIZE, + CERT_DATA_TYPE_QE_REPORT, + CERT_DATA_TYPE_PCK_CERT_CHAIN, +) +from tinfoil.attestation.verify_tdx import ( + TdxVerificationError, + PCKCertificateChain, + extract_pck_cert_chain, + verify_pck_chain, + verify_quote_signature, + verify_qe_report_signature, + verify_qe_report_data_binding, + verify_tdx_quote, + _bytes_to_p256_pubkey, + _signature_to_der, +) +from tinfoil.attestation.cert_utils import parse_pem_chain +from tinfoil.attestation.intel_root_ca import get_intel_root_ca, INTEL_SGX_ROOT_CA_PEM + + +# ============================================================================= +# Test Fixtures - Cryptographic Key Generation +# ============================================================================= + +def generate_p256_keypair(): + """Generate a P-256 key pair for testing.""" + private_key = ec.generate_private_key(ec.SECP256R1()) + public_key = private_key.public_key() + return private_key, public_key + + +def public_key_to_raw_bytes(public_key: ec.EllipticCurvePublicKey) -> bytes: + """Convert P-256 public key to raw 64-byte format (X || Y).""" + # Get uncompressed point (0x04 || X || Y) + uncompressed = public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) + # Strip the 0x04 prefix + return uncompressed[1:] + + +def sign_message(private_key: ec.EllipticCurvePrivateKey, message: bytes) -> bytes: + """Sign a message and return raw R||S signature (64 bytes). + + This hashes the message with SHA256 before signing. + """ + # Sign with SHA256 + der_signature = private_key.sign(message, ec.ECDSA(hashes.SHA256())) + + # Convert DER to raw R||S + r, s = decode_dss_signature(der_signature) + return r.to_bytes(32, byteorder='big') + s.to_bytes(32, byteorder='big') + + +def sign_prehashed(private_key: ec.EllipticCurvePrivateKey, digest: bytes) -> bytes: + """Sign a pre-hashed digest and return raw R||S signature (64 bytes). + + Use this when the message has already been hashed with SHA256. + """ + # Sign the pre-hashed digest + der_signature = private_key.sign(digest, ec.ECDSA(Prehashed(hashes.SHA256()))) + + # Convert DER to raw R||S + r, s = decode_dss_signature(der_signature) + return r.to_bytes(32, byteorder='big') + s.to_bytes(32, byteorder='big') + + +# ============================================================================= +# Test Fixtures - Quote Building with Real Signatures +# ============================================================================= + +def build_header_bytes() -> bytes: + """Build a valid TDX header. + + Note: Bytes 8-11 are reserved. Some older specs labeled these as + QE_SVN/PCE_SVN, but they are always zero in actual quotes. The real + SVN values come from PCK certificate extensions and QE Report. + """ + header = b'' + header += struct.pack(' bytes: + """Build a valid TdQuoteBody.""" + body = b'' + body += b'\x03' + b'\x00' * 15 # tee_tcb_svn + body += b'\xaa' * 48 # mr_seam + body += b'\x00' * 48 # mr_signer_seam + body += b'\x00' * 8 # seam_attributes + body += b'\x00\x00\x00\x10\x00\x00\x00\x00' # td_attributes + body += b'\xe7\x02\x06\x00\x00\x00\x00\x00' # xfam + body += b'\x11' * 48 # mr_td + body += b'\x00' * 48 # mr_config_id + body += b'\x00' * 48 # mr_owner + body += b'\x00' * 48 # mr_owner_config + body += b'\x22' * 48 # rtmr0 + body += b'\x33' * 48 # rtmr1 + body += b'\x44' * 48 # rtmr2 + body += b'\x00' * 48 # rtmr3 + body += b'\xab' * 32 + b'\xcd' * 32 # report_data + assert len(body) == TD_QUOTE_BODY_SIZE + return body + + +def build_qe_report_bytes(report_data: bytes = None) -> bytes: + """Build a QE report with specified report_data.""" + if report_data is None: + report_data = b'\x00' * 64 + + report = b'' + report += b'\x00' * 16 # cpu_svn + report += struct.pack(' bytes: + """Build the signed data section of the quote.""" + # PCK cert chain (type 5) + pck_chain = struct.pack(' bytes: + """Generate a self-signed certificate for testing.""" + from cryptography.x509.oid import NameOID + from cryptography.hazmat.primitives.asymmetric import ec + import datetime + + private_key = ec.generate_private_key(ec.SECP256R1()) + + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, cn), + ]) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=365)) + .sign(private_key, hashes.SHA256()) + ) + + return cert.public_bytes(serialization.Encoding.PEM) + + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + +class TestBytesToP256PubKey: + """Test _bytes_to_p256_pubkey helper.""" + + def test_valid_key(self): + """Test converting valid key bytes.""" + _, public_key = generate_p256_keypair() + raw_bytes = public_key_to_raw_bytes(public_key) + + result = _bytes_to_p256_pubkey(raw_bytes) + assert isinstance(result, ec.EllipticCurvePublicKey) + + def test_wrong_size(self): + """Test that wrong size raises error.""" + with pytest.raises(TdxVerificationError, match="expected 64"): + _bytes_to_p256_pubkey(b'\x00' * 63) + + +class TestSignatureToDer: + """Test _signature_to_der helper.""" + + def test_valid_signature(self): + """Test converting valid signature.""" + private_key, _ = generate_p256_keypair() + raw_sig = sign_message(private_key, b"test message") + + der_sig = _signature_to_der(raw_sig) + assert isinstance(der_sig, bytes) + assert len(der_sig) > 64 # DER is longer due to encoding + + def test_wrong_size(self): + """Test that wrong size raises error.""" + with pytest.raises(TdxVerificationError, match="expected 64"): + _signature_to_der(b'\x00' * 63) + + +class TestParsePemChain: + """Test parse_pem_chain helper.""" + + def test_parse_single_cert(self): + """Test parsing single certificate.""" + cert_pem = _generate_self_signed_cert_pem("Test") + certs = parse_pem_chain(cert_pem) + assert len(certs) == 1 + + def test_parse_multiple_certs(self): + """Test parsing multiple concatenated certificates.""" + cert1 = _generate_self_signed_cert_pem("Cert1") + cert2 = _generate_self_signed_cert_pem("Cert2") + cert3 = _generate_self_signed_cert_pem("Cert3") + + chain = cert1 + cert2 + cert3 + certs = parse_pem_chain(chain) + assert len(certs) == 3 + + def test_parse_with_whitespace(self): + """Test parsing with leading/trailing whitespace.""" + cert_pem = b'\n\n' + _generate_self_signed_cert_pem("Test") + b'\n\n' + certs = parse_pem_chain(cert_pem) + assert len(certs) == 1 + + +# ============================================================================= +# Quote Signature Verification Tests +# ============================================================================= + +class TestVerifyQuoteSignature: + """Test quote signature verification.""" + + def test_valid_signature(self): + """Test verification of valid quote signature.""" + raw_quote, attest_priv, _ = build_signed_quote_with_keys() + quote = parse_quote(raw_quote) + + # Should not raise + verify_quote_signature(quote, raw_quote) + + def test_tampered_body(self): + """Test that tampered body fails verification.""" + raw_quote, _, _ = build_signed_quote_with_keys() + + # Tamper with the body + tampered = bytearray(raw_quote) + tampered[100] ^= 0xFF # Flip a byte in the body + tampered = bytes(tampered) + + quote = parse_quote(tampered) + + with pytest.raises(TdxVerificationError, match="signature verification failed"): + verify_quote_signature(quote, tampered) + + +# ============================================================================= +# QE Report Data Binding Tests +# ============================================================================= + +class TestVerifyQeReportDataBinding: + """Test QE report data binding verification.""" + + def test_valid_binding(self): + """Test verification of valid binding.""" + raw_quote, _, _ = build_signed_quote_with_keys() + quote = parse_quote(raw_quote) + + # Should not raise + verify_qe_report_data_binding(quote) + + def test_tampered_auth_data(self): + """Test that wrong auth data fails binding check.""" + # Build quote with specific auth data + attest_priv, attest_pub = generate_p256_keypair() + + header = build_header_bytes() + body = build_body_bytes() + message_hash = hashlib.sha256(header + body).digest() + quote_signature = sign_prehashed(attest_priv, message_hash) + + attestation_key_raw = public_key_to_raw_bytes(attest_pub) + + # Use correct auth data for hash calculation + correct_auth_data = b'\x88' * 32 + qe_report_data_hash = hashlib.sha256(attestation_key_raw + correct_auth_data).digest() + qe_report_data = qe_report_data_hash + b'\x00' * 32 + + qe_report = build_qe_report_bytes(qe_report_data) + qe_priv, _ = generate_p256_keypair() + qe_signature = sign_message(qe_priv, qe_report) + + # But embed WRONG auth data in the quote + wrong_auth_data = b'\x99' * 32 # Different! + + cert_chain = ( + _generate_self_signed_cert_pem("PCK") + + _generate_self_signed_cert_pem("Int") + + INTEL_SGX_ROOT_CA_PEM + ) + + signed_data = _build_signed_data_bytes( + signature=quote_signature, + attestation_key=attestation_key_raw, + qe_report=qe_report, + qe_signature=qe_signature, + qe_auth_data=wrong_auth_data, # Wrong! + cert_chain_pem=cert_chain, + ) + + raw_quote = header + body + struct.pack(' str: + """Create a base64-encoded attestation document.""" + if compress: + compressed = gzip.compress(raw_quote) + return base64.b64encode(compressed).decode() + return base64.b64encode(raw_quote).decode() + + +# ============================================================================= +# Unit Tests for verify_tdx_attestation +# ============================================================================= + +class TestVerifyTdxAttestation: + """Test verify_tdx_attestation function.""" + + def test_invalid_base64(self): + """Test that invalid base64 raises error.""" + with pytest.raises(TdxAttestationError, match="Failed to decode base64"): + verify_tdx_attestation("not valid base64!!!") + + def test_invalid_gzip(self): + """Test that invalid gzip raises error.""" + # Valid base64 but not gzip + doc = base64.b64encode(b"not gzipped").decode() + with pytest.raises(TdxAttestationError, match="Failed to decompress"): + verify_tdx_attestation(doc, is_compressed=True) + +# ============================================================================= +# Integration Tests with attestation.py +# ============================================================================= + +class TestDocumentVerify: + """Test Document.verify() with TDX format.""" + + def test_document_verify_tdx(self): + """Test that Document.verify() handles TDX format.""" + raw_quote, _, _ = build_signed_quote_with_keys() + doc = create_test_attestation_doc(raw_quote) + + document = Document( + format=PredicateType.TDX_GUEST_V2, + body=doc, + ) + + with patch('tinfoil.attestation.attestation_tdx.verify_tdx_attestation') as mock_verify: + from tinfoil.attestation.attestation_tdx import TdxAttestationResult + mock_verify.return_value = TdxAttestationResult( + quote=MagicMock(), + pck_chain=MagicMock(), + pck_extensions=MagicMock(), + collateral=None, + tcb_level=None, + measurements=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + tls_key_fp="ff" * 32, + hpke_public_key="ee" * 32, + ) + + result = document.verify() + + assert isinstance(result, Verification) + assert result.measurement.type == PredicateType.TDX_GUEST_V2 + assert len(result.measurement.registers) == 5 + assert result.public_key_fp == "ff" * 32 + + +class TestMeasurementComparison: + """Test Measurement.assert_equal() with TDX measurements.""" + + def test_tdx_same_measurements_equal(self): + """Test that identical TDX measurements are equal.""" + m1 = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + m2 = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + + # Should not raise + m1.assert_equal(m2) + + def test_tdx_different_measurements_not_equal(self): + """Test that different TDX measurements raise error.""" + from tinfoil.attestation import MeasurementMismatchError + + m1 = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + m2 = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["ff" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + + with pytest.raises(MeasurementMismatchError): + m1.assert_equal(m2) + + def test_multiplatform_vs_tdx(self): + """Test multiplatform measurement comparison with TDX.""" + # Multiplatform: [SNP_measurement, RTMR1, RTMR2] + multiplatform = Measurement( + type=PredicateType.SNP_TDX_MULTIPLATFORM_v1, + registers=["snp" * 16, "cc" * 48, "dd" * 48], + ) + + # TDX: [MRTD, RTMR0, RTMR1, RTMR2, RTMR3] + tdx = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + + # RTMR1 and RTMR2 should match + # multiplatform.registers[1] == tdx.registers[2] (RTMR1) + # multiplatform.registers[2] == tdx.registers[3] (RTMR2) + multiplatform.assert_equal(tdx) # Should not raise + + +class TestMeasurementStr: + """Test Measurement.__str__() for TDX.""" + + def test_tdx_measurement_str(self): + """Test string representation of TDX measurement.""" + m = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["aa" * 48, "bb" * 48, "cc" * 48, "dd" * 48, "00" * 48], + ) + + s = str(m) + assert "TDX_GUEST_V2" in s or "tdx-guest" in s + assert "mrtd=" in s + assert "rtmr0=" in s + assert "rtmr1=" in s + assert "rtmr2=" in s + assert "rtmr3=" in s + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_verification_failures.py b/tests/test_verification_failures.py new file mode 100644 index 0000000..0cc4f18 --- /dev/null +++ b/tests/test_verification_failures.py @@ -0,0 +1,298 @@ +""" +Tests to ensure verification failures are properly propagated. + +These tests ensure that if any verification step fails: +1. The error is raised, not silently ignored +2. No HTTP client is created +3. No connection to the enclave is made + +This guards against bugs where failed verification would still allow +connections to proceed. +""" + +import pytest +from unittest.mock import patch, MagicMock + +from tinfoil.client import SecureClient +from tinfoil.attestation import ( + Measurement, + PredicateType, + MeasurementMismatchError, + HardwareMeasurementError, + verify_tdx_hardware, + HardwareMeasurement, +) + + +class TestMeasurementMismatch: + """Tests that measurement mismatches raise errors and block connections.""" + + def test_measurement_equals_raises_on_mismatch(self): + """Measurement.assert_equal() must raise MeasurementMismatchError on mismatch.""" + m1 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123"] + ) + m2 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["different"] + ) + + with pytest.raises(MeasurementMismatchError): + m1.assert_equal(m2) + + def test_measurement_equals_raises_on_register_count_mismatch(self): + """Measurement.assert_equal() must raise if register counts differ.""" + m1 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123"] + ) + m2 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123", "extra"] + ) + + with pytest.raises(MeasurementMismatchError): + m1.assert_equal(m2) + + def test_measurement_equals_passes_on_match(self): + """Measurement.assert_equal() must not raise if measurements match.""" + m1 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123"] + ) + m2 = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["abc123"] + ) + + # Should not raise + m1.assert_equal(m2) + + +class TestHardwareMeasurementVerification: + """Tests that hardware measurement failures raise errors.""" + + def test_verify_hardware_raises_on_no_match(self): + """verify_tdx_hardware() must raise HardwareMeasurementError if no match.""" + hardware_measurements = [ + HardwareMeasurement( + id="platform1@digest1", + mrtd="known_mrtd_1", + rtmr0="known_rtmr0_1" + ), + HardwareMeasurement( + id="platform2@digest2", + mrtd="known_mrtd_2", + rtmr0="known_rtmr0_2" + ), + ] + + enclave_measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["unknown_mrtd", "unknown_rtmr0", "rtmr1", "rtmr2", "rtmr3"] + ) + + with pytest.raises(HardwareMeasurementError, match="no matching hardware platform"): + verify_tdx_hardware(hardware_measurements, enclave_measurement) + + def test_verify_hardware_passes_on_match(self): + """verify_tdx_hardware() must return the matching measurement.""" + hardware_measurements = [ + HardwareMeasurement( + id="platform1@digest1", + mrtd="known_mrtd", + rtmr0="known_rtmr0" + ), + ] + + enclave_measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["known_mrtd", "known_rtmr0", "rtmr1", "rtmr2", "rtmr3"] + ) + + result = verify_tdx_hardware(hardware_measurements, enclave_measurement) + assert result.id == "platform1@digest1" + + +class TestSecureClientVerificationFailures: + """Tests that SecureClient properly blocks on verification failures.""" + + @patch('tinfoil.client.fetch_attestation') + def test_attestation_failure_blocks_verify(self, mock_fetch): + """If attestation fetch fails, verify() must raise.""" + mock_fetch.side_effect = Exception("Attestation fetch failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(Exception, match="Attestation fetch failed"): + client.verify() + + @patch('tinfoil.client.fetch_attestation') + def test_attestation_verification_failure_blocks_verify(self, mock_fetch): + """If attestation verification fails, verify() must raise.""" + mock_doc = MagicMock() + mock_doc.verify.side_effect = ValueError("TDX attestation verification failed") + mock_fetch.return_value = mock_doc + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(ValueError, match="TDX attestation verification failed"): + client.verify() + + @patch('tinfoil.client.fetch_attestation') + @patch('tinfoil.client.fetch_latest_digest') + @patch('tinfoil.client.fetch_attestation_bundle') + @patch('tinfoil.client.verify_attestation') + def test_measurement_mismatch_blocks_verify( + self, mock_verify_att, mock_fetch_bundle, mock_fetch_digest, mock_fetch_attestation + ): + """If code measurements don't match runtime, verify() must raise.""" + # Setup mocks + mock_fetch_digest.return_value = "test_digest" + mock_fetch_bundle.return_value = {} + + # Runtime measurement from enclave + runtime_measurement = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["runtime_measurement"] + ) + mock_verification = MagicMock() + mock_verification.measurement = runtime_measurement + mock_doc = MagicMock() + mock_doc.verify.return_value = mock_verification + mock_fetch_attestation.return_value = mock_doc + + # Code measurement from sigstore (different!) + code_measurement = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["different_code_measurement"] + ) + mock_verify_att.return_value = code_measurement + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(MeasurementMismatchError): + client.verify() + + @patch('tinfoil.client.fetch_attestation') + @patch('tinfoil.client.fetch_latest_digest') + @patch('tinfoil.client.fetch_attestation_bundle') + @patch('tinfoil.client.verify_attestation') + @patch('tinfoil.client.fetch_latest_hardware_measurements') + @patch('tinfoil.client.verify_tdx_hardware') + def test_hardware_mismatch_blocks_verify( + self, mock_verify_hw, mock_fetch_hw, mock_verify_att, + mock_fetch_bundle, mock_fetch_digest, mock_fetch_attestation + ): + """If TDX hardware measurements don't match, verify() must raise.""" + # Setup mocks + mock_fetch_digest.return_value = "test_digest" + mock_fetch_bundle.return_value = {} + + # TDX runtime measurement from enclave + runtime_measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["mrtd", "rtmr0", "rtmr1", "rtmr2", "rtmr3"] + ) + mock_verification = MagicMock() + mock_verification.measurement = runtime_measurement + mock_doc = MagicMock() + mock_doc.verify.return_value = mock_verification + mock_fetch_attestation.return_value = mock_doc + + # Code measurement from sigstore + code_measurement = Measurement( + type=PredicateType.TDX_GUEST_V2, + registers=["mrtd", "rtmr0", "rtmr1", "rtmr2", "rtmr3"] + ) + mock_verify_att.return_value = code_measurement + + # Hardware verification fails + mock_fetch_hw.return_value = [] + mock_verify_hw.side_effect = HardwareMeasurementError("no matching hardware platform") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(HardwareMeasurementError, match="no matching hardware platform"): + client.verify() + + @patch('tinfoil.client.fetch_attestation') + def test_http_client_not_created_on_verification_failure(self, mock_fetch): + """make_secure_http_client() must not return a client if verify() fails.""" + mock_fetch.side_effect = Exception("Attestation failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(Exception, match="Attestation failed"): + client.make_secure_http_client() + + # Verify ground_truth was never set + assert client.ground_truth is None + + @patch('tinfoil.client.fetch_attestation') + def test_async_http_client_not_created_on_verification_failure(self, mock_fetch): + """make_secure_async_http_client() must not return a client if verify() fails.""" + mock_fetch.side_effect = Exception("Attestation failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(Exception, match="Attestation failed"): + client.make_secure_async_http_client() + + # Verify ground_truth was never set + assert client.ground_truth is None + + @patch('tinfoil.client.fetch_attestation') + def test_get_http_client_calls_verify_first(self, mock_fetch): + """get_http_client() must call verify() before returning client.""" + mock_fetch.side_effect = Exception("Verification failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + + with pytest.raises(Exception, match="Verification failed"): + client.get_http_client() + + @patch('tinfoil.client.fetch_attestation') + def test_make_request_calls_verify_first(self, mock_fetch): + """make_request() must call verify() before making request.""" + import urllib.request + + mock_fetch.side_effect = Exception("Verification failed") + + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + req = urllib.request.Request("https://test.enclave.sh/api") + + with pytest.raises(Exception, match="Verification failed"): + client.make_request(req) + + +class TestDirectMeasurementVerification: + """Tests for direct measurement verification (no repo).""" + + @patch('tinfoil.client.fetch_attestation') + def test_snp_measurement_mismatch_raises(self, mock_fetch): + """If SNP measurement doesn't match provided measurement, must raise.""" + # Runtime measurement from enclave + runtime_measurement = Measurement( + type=PredicateType.SEV_GUEST_V2, + registers=["actual_snp_measurement"] + ) + mock_verification = MagicMock() + mock_verification.measurement = runtime_measurement + mock_doc = MagicMock() + mock_doc.verify.return_value = mock_verification + mock_fetch.return_value = mock_doc + + # Expect different measurement + client = SecureClient( + enclave="test.enclave.sh", + measurement={"snp_measurement": "expected_snp_measurement"} + ) + + with pytest.raises(ValueError, match="SNP measurement mismatch"): + client.verify() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_verification_failures_integration.py b/tests/test_verification_failures_integration.py new file mode 100644 index 0000000..5b99003 --- /dev/null +++ b/tests/test_verification_failures_integration.py @@ -0,0 +1,143 @@ +""" +Integration tests for verification failure handling. + +These tests verify that REAL verification failures are properly caught +at the integration level - not with mocks, but with actual enclaves +and mismatched repos. + +This guards against bugs like the Go verifier issue where hardware +measurement mismatches would silently continue instead of failing. +""" + +import pytest + +from tinfoil.client import SecureClient, get_router_address +from tinfoil.attestation import MeasurementMismatchError + +pytestmark = pytest.mark.integration + +# Router runs confidential-model-router, use gpt-oss as wrong repo +CORRECT_REPO = "tinfoilsh/confidential-model-router" +WRONG_REPO = "tinfoilsh/confidential-gpt-oss-120b-free" + + +@pytest.fixture(scope="module") +def router_enclave(): + """Fetch a router enclave address, skip all tests if unavailable.""" + try: + return get_router_address() + except Exception as e: + pytest.skip(f"Could not fetch router address: {e}") + + +class TestMeasurementMismatchIntegration: + """ + Tests that measurement mismatches between enclave and repo + are properly caught at the integration level. + """ + + def test_wrong_repo_fails_verification(self, router_enclave): + """ + Verifying an enclave against the WRONG repo must fail. + + This test: + 1. Gets a real router enclave + 2. Tries to verify it against gpt-oss-120b-free repo (WRONG) + 3. Expects MeasurementMismatchError + + This catches bugs where measurement comparison is skipped. + """ + print(f"\nTesting: {router_enclave}") + print(f"Against WRONG repo: {WRONG_REPO}") + print("Expected: MeasurementMismatchError") + + client = SecureClient(enclave=router_enclave, repo=WRONG_REPO) + + with pytest.raises(MeasurementMismatchError): + client.verify() + + print("✓ Correctly rejected mismatched measurements") + + def test_wrong_repo_blocks_http_client(self, router_enclave): + """ + make_secure_http_client() must fail if measurements don't match. + + This catches bugs where HTTP client is created despite failed verification. + """ + client = SecureClient(enclave=router_enclave, repo=WRONG_REPO) + + with pytest.raises(MeasurementMismatchError): + client.make_secure_http_client() + + # Ground truth should NOT be set if verification failed + assert client.ground_truth is None, \ + "ground_truth was set despite verification failure!" + + print("✓ HTTP client correctly blocked on mismatch") + + def test_correct_repo_passes_verification(self, router_enclave): + """ + Sanity check: correct repo should pass verification. + """ + client = SecureClient(enclave=router_enclave, repo=CORRECT_REPO) + ground_truth = client.verify() + + assert ground_truth is not None + assert ground_truth.public_key is not None + assert ground_truth.measurement is not None + + print(f"✓ Correct repo verified successfully") + print(f" Measurement: {ground_truth.measurement.fingerprint()[:32]}...") + + +class TestDirectMeasurementIntegration: + """ + Tests for the direct measurement verification path. + """ + + def test_wrong_pinned_measurement_fails(self, router_enclave): + """ + If user provides a specific measurement that doesn't match + the enclave, verification must fail. + """ + # This is clearly a fake measurement + fake_measurement = { + "snp_measurement": "0000000000000000000000000000000000000000000000000000000000000000" + } + + client = SecureClient(enclave=router_enclave, measurement=fake_measurement) + + with pytest.raises(ValueError, match="measurement mismatch"): + client.verify() + + print("✓ Correctly rejected fake pinned measurement") + + +class TestNoSilentFailures: + """ + Tests that verification failures are NEVER silently ignored. + """ + + def test_verification_required_before_request(self, router_enclave): + """ + Any attempt to use the client must trigger verification. + If verification would fail, the request must also fail. + """ + client = SecureClient(enclave=router_enclave, repo=WRONG_REPO) + + # Try to get HTTP client - should fail + with pytest.raises(MeasurementMismatchError): + client.get_http_client() + + # Try to make request - should also fail + import urllib.request + req = urllib.request.Request(f"https://{router_enclave}/health") + + with pytest.raises(MeasurementMismatchError): + client.make_request(req) + + print("✓ All client methods properly block on verification failure") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])