Skip to content
2 changes: 1 addition & 1 deletion deps/td-shim
249 changes: 205 additions & 44 deletions src/migtd/src/mig_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ mod v2 {
const SERVTD_ATTR_IGNORE_RTMR3: u64 = 0x200_0000_0000;

const SERVTD_TYPE_MIGTD: u16 = 0;
const TD_INFO_OFFSET: usize = 512;

lazy_static! {
pub static ref LOCAL_TCB_INFO: Once<PolicyEvaluationInfo> = Once::new();
Expand Down Expand Up @@ -136,13 +135,6 @@ mod v2 {
.ok_or(PolicyError::InvalidParameter)
}

pub fn get_init_tcb_evaluation_info(
init_report: &TdxReport,
init_policy: &VerifiedPolicy,
) -> Result<PolicyEvaluationInfo, PolicyError> {
setup_evaluation_data_with_tdreport(init_report, init_policy)
}

/// Get reference to the global verified policy
/// Returns None if the policy hasn't been initialized yet
pub fn get_verified_policy() -> Option<&'static VerifiedPolicy<'static>> {
Expand Down Expand Up @@ -263,13 +255,14 @@ mod v2 {
}

// Authenticate the migtd-old from migtd-new side
// Per GHCI 1.5: init_tdinfo is a TDINFO_STRUCT (not full TDREPORT),
// and there is no separate init_policy JSON blob.
pub fn authenticate_rebinding_old(
tdreport_src: &[u8],
event_log_src: &[u8],
mig_policy_src: &[u8],
init_policy: &[u8],
init_tdinfo: &[u8],
init_event_log: &[u8],
init_td_report: &[u8],
servtd_ext_src: &[u8],
) -> Result<Vec<u8>, PolicyError> {
let policy_issuer_chain = get_policy_issuer_chain().ok_or(PolicyError::InvalidParameter)?;
Expand All @@ -284,29 +277,30 @@ mod v2 {
)?;
let policy = get_verified_policy().ok_or(PolicyError::InvalidParameter)?;

// Verify the td report init / event log init / policy init
// Verify the init tdinfo against servtd_ext hash
let servtd_ext_src_obj =
ServtdExt::read_from_bytes(servtd_ext_src).ok_or(PolicyError::InvalidParameter)?;
let init_tdreport = verify_init_tdreport(init_td_report, &servtd_ext_src_obj)?;
let init_td_info = verify_init_tdinfo(init_tdinfo, &servtd_ext_src_obj)?;
let _engine_svn = policy
.servtd_tcb_mapping
.get_engine_svn_by_measurements(&Measurements::new_from_bytes(
&init_tdreport.td_info.mrtd,
&init_tdreport.td_info.rtmr0,
&init_tdreport.td_info.rtmr1,
&init_td_info.mrtd,
&init_td_info.rtmr0,
&init_td_info.rtmr1,
None,
None,
))
.ok_or(PolicyError::SvnMismatch)?;
let verified_policy_init = verify_policy_and_event_log(

// Verify init event log integrity against RTMRs from init tdinfo
verify_event_log(
init_event_log,
init_policy,
policy_issuer_chain,
&get_rtmrs_from_tdreport(&init_tdreport)?,
)?;
&get_rtmrs_from_tdinfo(&init_td_info)?,
)
.map_err(|_| PolicyError::InvalidEventLog)?;

let relative_reference =
get_init_tcb_evaluation_info(&init_tdreport, &verified_policy_init)?;
// Use local policy's tcb_mapping with init tdinfo measurements
let relative_reference = setup_evaluation_data_with_tdinfo(&init_td_info, policy)?;
policy.policy_data.evaluate_policy_common(
&evaluation_data_src,
&relative_reference,
Expand Down Expand Up @@ -456,51 +450,61 @@ mod v2 {
Ok(tdx_report)
}

/// Per GHCI 1.5: accepts TDINFO_STRUCT bytes directly (not full TDREPORT)
fn verify_servtd_hash(
servtd_report: &[u8],
tdinfo_bytes: &[u8],
servtd_attr: u64,
init_servtd_hash: &[u8],
) -> Result<TdxReport, PolicyError> {
if servtd_report.len() < TD_INFO_OFFSET + size_of::<TdInfo>() {
) -> Result<TdInfo, PolicyError> {
if tdinfo_bytes.len() < size_of::<TdInfo>() {
return Err(PolicyError::InvalidParameter);
}

// Extract TdInfo from the report
let mut td_report =
TdxReport::read_from_bytes(servtd_report).ok_or(PolicyError::InvalidTdReport)?;
// Parse TdInfo directly from bytes
let mut td_info = {
let mut uninit = core::mem::MaybeUninit::<TdInfo>::uninit();
unsafe {
core::ptr::copy_nonoverlapping(
tdinfo_bytes.as_ptr(),
uninit.as_mut_ptr() as *mut u8,
size_of::<TdInfo>(),
);
uninit.assume_init()
}
};

if (servtd_attr & SERVTD_ATTR_IGNORE_ATTRIBUTES) != 0 {
td_report.td_info.attributes.fill(0);
td_info.attributes.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_XFAM) != 0 {
td_report.td_info.xfam.fill(0);
td_info.xfam.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_MRTD) != 0 {
td_report.td_info.mrtd.fill(0);
td_info.mrtd.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_MRCONFIGID) != 0 {
td_report.td_info.mrconfig_id.fill(0);
td_info.mrconfig_id.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_MROWNER) != 0 {
td_report.td_info.mrowner.fill(0);
td_info.mrowner.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_MROWNERCONFIG) != 0 {
td_report.td_info.mrownerconfig.fill(0);
td_info.mrownerconfig.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_RTMR0) != 0 {
td_report.td_info.rtmr0.fill(0);
td_info.rtmr0.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_RTMR1) != 0 {
td_report.td_info.rtmr1.fill(0);
td_info.rtmr1.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_RTMR2) != 0 {
td_report.td_info.rtmr2.fill(0);
td_info.rtmr2.fill(0);
}
if (servtd_attr & SERVTD_ATTR_IGNORE_RTMR3) != 0 {
td_report.td_info.rtmr3.fill(0);
td_info.rtmr3.fill(0);
}

let info_hash = digest_sha384(td_report.td_info.as_bytes())
let info_hash = digest_sha384(td_info.as_bytes())
.map_err(|_| PolicyError::HashCalculation)?;

// Calculate ServTD hash: SHA384(info_hash || type || attr)
Expand All @@ -521,20 +525,62 @@ mod v2 {
return Err(PolicyError::InvalidTdReport);
}

Ok(td_report)
Ok(td_info)
}

fn verify_init_tdreport(
init_report: &[u8],
/// Per GHCI 1.5: verifies TDINFO_STRUCT against servtd_ext hash
fn verify_init_tdinfo(
init_tdinfo: &[u8],
servtd_ext: &ServtdExt,
) -> Result<TdxReport, PolicyError> {
) -> Result<TdInfo, PolicyError> {
verify_servtd_hash(
init_report,
init_tdinfo,
u64::from_le_bytes(servtd_ext.init_attr),
&servtd_ext.init_servtd_info_hash,
)
}

fn get_rtmrs_from_tdinfo(
td_info: &TdInfo,
) -> Result<[[u8; SHA384_DIGEST_SIZE]; 4], PolicyError> {
let mut rtmrs = [[0u8; SHA384_DIGEST_SIZE]; 4];
rtmrs[0].copy_from_slice(&td_info.rtmr0);
rtmrs[1].copy_from_slice(&td_info.rtmr1);
rtmrs[2].copy_from_slice(&td_info.rtmr2);
rtmrs[3].copy_from_slice(&td_info.rtmr3);
Ok(rtmrs)
}

fn setup_evaluation_data_with_tdinfo(
td_info: &TdInfo,
policy: &VerifiedPolicy,
) -> Result<PolicyEvaluationInfo, PolicyError> {
let migtd_svn = policy.servtd_tcb_mapping.get_engine_svn_by_measurements(
&Measurements::new_from_bytes(
&td_info.mrtd,
&td_info.rtmr0,
&td_info.rtmr1,
None,
None,
),
);

let migtd_tcb = migtd_svn.and_then(|svn| policy.servtd_identity.get_tcb_level_by_svn(svn));

Ok(PolicyEvaluationInfo {
tee_tcb_svn: None,
tcb_date: None,
tcb_status: None,
tcb_evaluation_number: None,
fmspc: None,
migtd_isvsvn: migtd_svn,
migtd_tcb_date: migtd_tcb.map(|tcb| tcb.tcb_date.clone()),
migtd_tcb_status: migtd_tcb.map(|tcb| tcb.tcb_status.clone()),
pck_crl_num: None,
root_ca_crl_num: None,
})
}

fn setup_evaluation_data(
fmspc: [u8; 6],
suppl_data: &[u8],
Expand Down Expand Up @@ -659,6 +705,121 @@ mod v2 {
let iso_date = unix_to_iso8601(timestamp).unwrap();
assert_eq!(iso_date, "2024-01-01T00:00:00Z");
}

#[test]
fn test_verify_servtd_hash_valid() {
// Build a 512-byte TDINFO_STRUCT with known content
let mut tdinfo_bytes = [0u8; 512];
tdinfo_bytes[0..8].copy_from_slice(&[0x01; 8]); // attributes
tdinfo_bytes[8..16].copy_from_slice(&[0x02; 8]); // xfam

// Compute expected hash: SHA384(SHA384(tdinfo) || type(u16) || attr(u64))
let servtd_attr: u64 = 0;
let info_hash = digest_sha384(&tdinfo_bytes).unwrap();
let mut buffer = [0u8; SHA384_DIGEST_SIZE + size_of::<u16>() + size_of::<u64>()];
buffer[..SHA384_DIGEST_SIZE].copy_from_slice(&info_hash);
buffer[SHA384_DIGEST_SIZE..SHA384_DIGEST_SIZE + 2]
.copy_from_slice(&SERVTD_TYPE_MIGTD.to_le_bytes());
buffer[SHA384_DIGEST_SIZE + 2..SHA384_DIGEST_SIZE + 10]
.copy_from_slice(&servtd_attr.to_le_bytes());
let expected_hash = digest_sha384(&buffer).unwrap();

let result = verify_servtd_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
assert!(result.is_ok());
let td_info = result.unwrap();
assert_eq!(td_info.attributes, [0x01; 8]);
assert_eq!(td_info.xfam, [0x02; 8]);
}

#[test]
fn test_verify_servtd_hash_wrong_hash() {
let tdinfo_bytes = [0u8; 512];
let wrong_hash = [0xFFu8; 48];
let result = verify_servtd_hash(&tdinfo_bytes, 0, &wrong_hash);
assert!(result.is_err());
}

#[test]
fn test_verify_servtd_hash_short_input() {
let short = [0u8; 256]; // too small for TdInfo (512 bytes)
let result = verify_servtd_hash(&short, 0, &[0u8; 48]);
assert!(matches!(result, Err(PolicyError::InvalidParameter)));
}

#[test]
fn test_verify_servtd_hash_with_ignore_attributes() {
// Build TdInfo with non-zero attributes
let mut tdinfo_bytes = [0u8; 512];
tdinfo_bytes[0..8].copy_from_slice(&[0xFF; 8]); // attributes

// Compute hash with attributes zeroed (IGNORE_ATTRIBUTES flag)
let servtd_attr = SERVTD_ATTR_IGNORE_ATTRIBUTES;
let mut zeroed = tdinfo_bytes;
zeroed[0..8].fill(0); // zero attributes for hash computation
let info_hash = digest_sha384(&zeroed).unwrap();
let mut buffer = [0u8; SHA384_DIGEST_SIZE + size_of::<u16>() + size_of::<u64>()];
buffer[..SHA384_DIGEST_SIZE].copy_from_slice(&info_hash);
buffer[SHA384_DIGEST_SIZE..SHA384_DIGEST_SIZE + 2]
.copy_from_slice(&SERVTD_TYPE_MIGTD.to_le_bytes());
buffer[SHA384_DIGEST_SIZE + 2..SHA384_DIGEST_SIZE + 10]
.copy_from_slice(&servtd_attr.to_le_bytes());
let expected_hash = digest_sha384(&buffer).unwrap();

let result = verify_servtd_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
assert!(result.is_ok());
}

#[test]
fn test_verify_servtd_hash_with_ignore_mrowner() {
// Build TdInfo with non-zero mrowner at offset 112..160
let mut tdinfo_bytes = [0u8; 512];
tdinfo_bytes[112..160].copy_from_slice(&[0xAA; 48]); // mrowner

// Compute hash with mrowner zeroed (IGNORE_MROWNER flag)
let servtd_attr = SERVTD_ATTR_IGNORE_MROWNER;
let mut zeroed = tdinfo_bytes;
zeroed[112..160].fill(0);
let info_hash = digest_sha384(&zeroed).unwrap();
let mut buffer = [0u8; SHA384_DIGEST_SIZE + size_of::<u16>() + size_of::<u64>()];
buffer[..SHA384_DIGEST_SIZE].copy_from_slice(&info_hash);
buffer[SHA384_DIGEST_SIZE..SHA384_DIGEST_SIZE + 2]
.copy_from_slice(&SERVTD_TYPE_MIGTD.to_le_bytes());
buffer[SHA384_DIGEST_SIZE + 2..SHA384_DIGEST_SIZE + 10]
.copy_from_slice(&servtd_attr.to_le_bytes());
let expected_hash = digest_sha384(&buffer).unwrap();

let result = verify_servtd_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
assert!(result.is_ok());
// mrowner should be zeroed in the returned TdInfo
assert_eq!(result.unwrap().mrowner, [0u8; 48]);
}

#[test]
fn test_get_rtmrs_from_tdinfo() {
use tdx_tdcall::tdreport::TdInfo;
let mut tdinfo_bytes = [0u8; 512];
// RTMR offsets in TdInfo: rtmr0 at 208, rtmr1 at 256, rtmr2 at 304, rtmr3 at 352
tdinfo_bytes[208..256].copy_from_slice(&[0x01; 48]); // rtmr0
tdinfo_bytes[256..304].copy_from_slice(&[0x02; 48]); // rtmr1
tdinfo_bytes[304..352].copy_from_slice(&[0x03; 48]); // rtmr2
tdinfo_bytes[352..400].copy_from_slice(&[0x04; 48]); // rtmr3

let td_info = unsafe {
let mut uninit = core::mem::MaybeUninit::<TdInfo>::uninit();
core::ptr::copy_nonoverlapping(
tdinfo_bytes.as_ptr(),
uninit.as_mut_ptr() as *mut u8,
size_of::<TdInfo>(),
);
uninit.assume_init()
};

let rtmrs = get_rtmrs_from_tdinfo(&td_info).unwrap();
assert_eq!(rtmrs[0], [0x01; 48]);
assert_eq!(rtmrs[1], [0x02; 48]);
assert_eq!(rtmrs[2], [0x03; 48]);
assert_eq!(rtmrs[3], [0x04; 48]);
}
}

fn get_rtmrs_from_suppl_data(
Expand Down
4 changes: 2 additions & 2 deletions src/migtd/src/migration/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: BSD-2-Clause-Patent

#[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))]
#[cfg(all(feature = "main", feature = "vmcall-raw", feature = "policy_v2"))]
use crate::migration::rebinding::RebindingInfo;

use super::*;
Expand Down Expand Up @@ -257,7 +257,7 @@ pub struct RequestDataBuffer<'a> {
#[cfg(feature = "vmcall-raw")]
pub enum WaitForRequestResponse {
StartMigration(MigrationInformation),
#[cfg(feature = "policy_v2")]
#[cfg(all(feature = "main", feature = "policy_v2"))]
StartRebinding(RebindingInfo),
GetTdReport(ReportInfo),
EnableLogArea(EnableLogAreaInfo),
Expand Down
Loading
Loading