Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 56 additions & 60 deletions src/migtd/src/mig_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ mod v2 {
const SERVTD_ATTR_IGNORE_RTMR2: u64 = 0x100_0000_0000;
const SERVTD_ATTR_IGNORE_RTMR3: u64 = 0x200_0000_0000;

const SERVTD_TYPE_MIGTD: u16 = 0;

lazy_static! {
pub static ref VERIFIED_POLICY: Once<VerifiedPolicy<'static>> = Once::new();
}
Expand Down Expand Up @@ -461,11 +459,18 @@ mod v2 {
Ok(tdx_report)
}

/// Per GHCI 1.5: accepts TDINFO_STRUCT bytes directly (not full TDREPORT)
fn verify_servtd_hash(
/// Per GHCI 1.5: verifies TDINFO_STRUCT integrity against ServtdExt info hash.
///
/// Parses the TDINFO bytes, applies IGNORE masks from `servtd_attr`, computes
/// `SHA384(masked_tdinfo)`, and compares to `init_servtd_info_hash`.
/// Returns the parsed TdInfo on success.
///
/// This function is only called when SERVTD_EXT is supported (the sender
/// gates on TDCS.ATTRIBUTES bit 17 before sending init_tdinfo).
fn verify_servtd_info_hash(
tdinfo_bytes: &[u8],
servtd_attr: u64,
init_servtd_hash: &[u8],
init_servtd_info_hash: &[u8],
) -> Result<TdInfo, PolicyError> {
if tdinfo_bytes.len() < size_of::<TdInfo>() {
return Err(PolicyError::InvalidParameter);
Expand Down Expand Up @@ -518,33 +523,20 @@ mod v2 {
let info_hash =
digest_sha384(td_info.as_bytes()).map_err(|_| PolicyError::HashCalculation)?;

// Calculate ServTD hash: SHA384(info_hash || type || attr)
let mut buffer = [0u8; SHA384_DIGEST_SIZE + size_of::<u16>() + size_of::<u64>()];
let mut offset = 0;

buffer[offset..offset + SHA384_DIGEST_SIZE].copy_from_slice(&info_hash);
offset += SHA384_DIGEST_SIZE;

buffer[offset..offset + size_of::<u16>()].copy_from_slice(&SERVTD_TYPE_MIGTD.to_le_bytes());
offset += size_of::<u16>();

buffer[offset..offset + size_of::<u64>()].copy_from_slice(&servtd_attr.to_le_bytes());

let calculated_hash = digest_sha384(&buffer).map_err(|_| PolicyError::HashCalculation)?;

if calculated_hash.as_slice() != init_servtd_hash {
if info_hash.as_slice() != init_servtd_info_hash {
log::error!("verify_servtd_info_hash: HASH MISMATCH\n");
return Err(PolicyError::InvalidTdReport);
}

Ok(td_info)
}

/// Per GHCI 1.5: verifies TDINFO_STRUCT against servtd_ext hash
/// Per GHCI 1.5: verifies TDINFO_STRUCT against servtd_ext info hash
fn verify_init_tdinfo(
init_tdinfo: &[u8],
servtd_ext: &ServtdExt,
) -> Result<TdInfo, PolicyError> {
verify_servtd_hash(
verify_servtd_info_hash(
init_tdinfo,
u64::from_le_bytes(servtd_ext.init_attr),
&servtd_ext.init_servtd_info_hash,
Expand Down Expand Up @@ -923,47 +915,40 @@ mod v2 {
}

#[test]
fn test_verify_servtd_hash_valid() {
fn test_verify_servtd_info_hash_valid() {
// Build a 512-byte TDINFO_STRUCT with known content
let mut tdinfo_bytes = [0u8; 512];
tdinfo_bytes[0..8].copy_from_slice(&[0x01; 8]); // attributes
tdinfo_bytes[8..16].copy_from_slice(&[0x02; 8]); // xfam

// Compute expected hash: SHA384(SHA384(tdinfo) || type(u16) || attr(u64))
// Compute expected hash: SHA384(tdinfo)
let servtd_attr: u64 = 0;
let info_hash = digest_sha384(&tdinfo_bytes).unwrap();
let mut buffer = [0u8; SHA384_DIGEST_SIZE + size_of::<u16>() + size_of::<u64>()];
buffer[..SHA384_DIGEST_SIZE].copy_from_slice(&info_hash);
buffer[SHA384_DIGEST_SIZE..SHA384_DIGEST_SIZE + 2]
.copy_from_slice(&SERVTD_TYPE_MIGTD.to_le_bytes());
buffer[SHA384_DIGEST_SIZE + 2..SHA384_DIGEST_SIZE + 10]
.copy_from_slice(&servtd_attr.to_le_bytes());
let expected_hash = digest_sha384(&buffer).unwrap();

let result = verify_servtd_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
let expected_hash = digest_sha384(&tdinfo_bytes).unwrap();

let result = verify_servtd_info_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
assert!(result.is_ok());
let td_info = result.unwrap();
assert_eq!(td_info.attributes, [0x01; 8]);
assert_eq!(td_info.xfam, [0x02; 8]);
}

#[test]
fn test_verify_servtd_hash_wrong_hash() {
fn test_verify_servtd_info_hash_wrong_hash() {
let tdinfo_bytes = [0u8; 512];
let wrong_hash = [0xFFu8; 48];
let result = verify_servtd_hash(&tdinfo_bytes, 0, &wrong_hash);
let result = verify_servtd_info_hash(&tdinfo_bytes, 0, &wrong_hash);
assert!(result.is_err());
}

#[test]
fn test_verify_servtd_hash_short_input() {
fn test_verify_servtd_info_hash_short_input() {
let short = [0u8; 256]; // too small for TdInfo (512 bytes)
let result = verify_servtd_hash(&short, 0, &[0u8; 48]);
let result = verify_servtd_info_hash(&short, 0, &[0u8; 48]);
assert!(matches!(result, Err(PolicyError::InvalidParameter)));
}

#[test]
fn test_verify_servtd_hash_with_ignore_attributes() {
fn test_verify_servtd_info_hash_with_ignore_attributes() {
// Build TdInfo with non-zero attributes
let mut tdinfo_bytes = [0u8; 512];
tdinfo_bytes[0..8].copy_from_slice(&[0xFF; 8]); // attributes
Expand All @@ -972,21 +957,14 @@ mod v2 {
let servtd_attr = SERVTD_ATTR_IGNORE_ATTRIBUTES;
let mut zeroed = tdinfo_bytes;
zeroed[0..8].fill(0); // zero attributes for hash computation
let info_hash = digest_sha384(&zeroed).unwrap();
let mut buffer = [0u8; SHA384_DIGEST_SIZE + size_of::<u16>() + size_of::<u64>()];
buffer[..SHA384_DIGEST_SIZE].copy_from_slice(&info_hash);
buffer[SHA384_DIGEST_SIZE..SHA384_DIGEST_SIZE + 2]
.copy_from_slice(&SERVTD_TYPE_MIGTD.to_le_bytes());
buffer[SHA384_DIGEST_SIZE + 2..SHA384_DIGEST_SIZE + 10]
.copy_from_slice(&servtd_attr.to_le_bytes());
let expected_hash = digest_sha384(&buffer).unwrap();

let result = verify_servtd_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
let expected_hash = digest_sha384(&zeroed).unwrap();

let result = verify_servtd_info_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
assert!(result.is_ok());
}

#[test]
fn test_verify_servtd_hash_with_ignore_mrowner() {
fn test_verify_servtd_info_hash_with_ignore_mrowner() {
// Build TdInfo with non-zero mrowner at offset 112..160
let mut tdinfo_bytes = [0u8; 512];
tdinfo_bytes[112..160].copy_from_slice(&[0xAA; 48]); // mrowner
Expand All @@ -995,21 +973,39 @@ mod v2 {
let servtd_attr = SERVTD_ATTR_IGNORE_MROWNER;
let mut zeroed = tdinfo_bytes;
zeroed[112..160].fill(0);
let info_hash = digest_sha384(&zeroed).unwrap();
let mut buffer = [0u8; SHA384_DIGEST_SIZE + size_of::<u16>() + size_of::<u64>()];
buffer[..SHA384_DIGEST_SIZE].copy_from_slice(&info_hash);
buffer[SHA384_DIGEST_SIZE..SHA384_DIGEST_SIZE + 2]
.copy_from_slice(&SERVTD_TYPE_MIGTD.to_le_bytes());
buffer[SHA384_DIGEST_SIZE + 2..SHA384_DIGEST_SIZE + 10]
.copy_from_slice(&servtd_attr.to_le_bytes());
let expected_hash = digest_sha384(&buffer).unwrap();

let result = verify_servtd_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
let expected_hash = digest_sha384(&zeroed).unwrap();

let result = verify_servtd_info_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
assert!(result.is_ok());
// mrowner should be zeroed in the returned TdInfo
assert_eq!(result.unwrap().mrowner, [0u8; 48]);
}

#[test]
fn test_verify_servtd_info_hash_with_combined_ignore_flags() {
// Build TdInfo with non-zero content in multiple fields
let mut tdinfo_bytes = [0u8; 512];
tdinfo_bytes[0..8].copy_from_slice(&[0xFF; 8]); // attributes (masked)
tdinfo_bytes[8..16].copy_from_slice(&[0xEE; 8]); // xfam (masked)
tdinfo_bytes[16..64].copy_from_slice(&[0xBB; 48]); // mrtd (not masked)
tdinfo_bytes[112..160].copy_from_slice(&[0xCC; 48]); // mrowner (not masked)

// Compute hash with attributes and xfam zeroed, mrtd and mrowner intact
let servtd_attr = SERVTD_ATTR_IGNORE_ATTRIBUTES | SERVTD_ATTR_IGNORE_XFAM;
let mut zeroed = tdinfo_bytes;
zeroed[0..8].fill(0);
zeroed[8..16].fill(0);
let expected_hash = digest_sha384(&zeroed).unwrap();

let result = verify_servtd_info_hash(&tdinfo_bytes, servtd_attr, &expected_hash);
assert!(result.is_ok());
let td_info = result.unwrap();
assert_eq!(td_info.attributes, [0u8; 8]);
assert_eq!(td_info.xfam, [0u8; 8]);
assert_eq!(td_info.mrtd, [0xBB; 48]);
assert_eq!(td_info.mrowner, [0xCC; 48]);
}

#[test]
fn test_get_rtmrs_from_tdinfo() {
use tdx_tdcall::tdreport::TdInfo;
Expand Down
Loading