diff --git a/proxy_agent_shared/Cargo.toml b/proxy_agent_shared/Cargo.toml index 44825b8d..30ca90e7 100644 --- a/proxy_agent_shared/Cargo.toml +++ b/proxy_agent_shared/Cargo.toml @@ -36,7 +36,7 @@ features = [ [target.'cfg(windows)'.dependencies] windows-service = "0.7.0" # windows NT service -winreg = "0.11.0" # windows reg read/write +winreg = "0.11" # windows reg read/write serde-xml-rs = "0.8.1" # xml Deserializer with xml attribute chrono = "0.4.41" # parse date time string diff --git a/proxy_agent_shared/src/windows.rs b/proxy_agent_shared/src/windows.rs index 29c7ed36..f70c8741 100644 --- a/proxy_agent_shared/src/windows.rs +++ b/proxy_agent_shared/src/windows.rs @@ -46,9 +46,16 @@ use windows_sys::Win32::System::Threading::{ use winreg::enums::*; use winreg::RegKey; -pub fn read_reg_int(key_name: &str, value_name: &str, default_value: Option) -> Option { +/// Open an existing HKLM registry key with `KEY_READ` permissions. +fn open_reg_key(key_name: &str) -> Result { let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); - match hklm.open_subkey(key_name) { + let reg_key = hklm.open_subkey(key_name)?; + Ok(reg_key) +} + +/// Read a REG_DWORD value from an existing HKLM key. +pub fn read_reg_int(key_name: &str, value_name: &str, default_value: Option) -> Option { + match open_reg_key(key_name) { Ok(key) => match key.get_value(value_name) { Ok(val) => return Some(val), Err(e) => { @@ -64,9 +71,8 @@ pub fn read_reg_int(key_name: &str, value_name: &str, default_value: Option } pub fn read_reg_string(key_name: &str, value_name: &str, default_value: String) -> String { - let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); - - if let Ok(key) = hklm.open_subkey(key_name) { + let key = open_reg_key(key_name); + if let Ok(key) = key { if let Ok(val) = key.get_value(value_name) { return val; } @@ -75,19 +81,91 @@ pub fn read_reg_string(key_name: &str, value_name: &str, default_value: String) default_value } -pub fn set_reg_string(key_name: &str, value_name: &str, value: String) -> Result<()> { +/// Open an existing HKLM registry key with `KEY_ALL_ACCESS` permissions, or create it if it doesn't exist. +fn open_create_reg_key(key_name: &str) -> Result { let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); let (key, _) = hklm.create_subkey(key_name)?; - key.set_value(value_name, &value)?; + Ok(key) +} + +/// Create (or open) an HKLM registry key. +pub fn add_registry_key(key_name: &str) -> Result<()> { + open_create_reg_key(key_name)?; + Ok(()) +} + +/// Write a REG_SZ value under an existing HKLM key. +/// Create an existing HKLM registry key if it doesn't exist. +pub fn set_reg_string(key_name: &str, value_name: &str, value: String) -> Result<()> { + open_create_reg_key(key_name)?.set_value(value_name, &value)?; + Ok(()) +} + +/// Write a REG_DWORD value under an existing HKLM key. +/// Create an existing HKLM registry key if it doesn't exist. +pub fn set_registry_entry_dword(key_name: &str, value_name: &str, value: u32) -> Result<()> { + open_create_reg_key(key_name)?.set_value(value_name, &value)?; Ok(()) } +/// Delete an existing HKLM registry key. pub fn remove_reg_key(key_name: &str) -> Result<()> { let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); hklm.delete_subkey_all(key_name)?; Ok(()) } +/// Open an existing HKLM registry key with `KEY_ALL_ACCESS` permissions. +fn open_reg_key_with_all_access(key_name: &str) -> Result { + let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); + let reg_key = hklm.open_subkey_with_flags(key_name, KEY_ALL_ACCESS)?; + Ok(reg_key) +} + +/// Delete a single named value from an existing HKLM key. +pub fn delete_registry_value(key_name: &str, value_name: &str) -> Result<()> { + open_reg_key_with_all_access(key_name)?.delete_value(value_name)?; + Ok(()) +} + +/// Delete a subkey from an HKLM key. When `recursive` is true, the entire +/// subtree is removed; otherwise only the leaf subkey is deleted. +/// +/// Best-effort: once the parent key opens, a missing or failed child +/// delete is ignored (`Ok`). Only failure to open the parent surfaces as +/// `Err`. +pub fn delete_registry_key(key_name: &str, sub_key_to_delete: &str, recursive: bool) -> Result<()> { + let parent = open_reg_key_with_all_access(key_name)?; + // Best-effort delete: any delete error (e.g. missing subkey) is ignored. + let _ = if recursive { + parent.delete_subkey_all(sub_key_to_delete) + } else { + parent.delete_subkey(sub_key_to_delete) + }; + Ok(()) +} + +/// Enumerate the immediate subkey names of an HKLM key. +pub fn get_subkey_names(key_name: &str) -> Result> { + let reg_key = open_reg_key(key_name)?; + let mut sub_keys = Vec::new(); + for sub_key in reg_key.enum_keys() { + sub_keys.push(sub_key?); + } + Ok(sub_keys) +} + +/// Enumerate the value names under an HKLM key. +pub fn get_value_names(key_name: &str) -> Result> { + let reg_key = open_reg_key(key_name)?; + let mut value_names = Vec::new(); + for value in reg_key.enum_values() { + let (name, _) = value?; + value_names.push(name); + } + Ok(value_names) +} + const OS_VERSION_REGISTRY_KEY: &str = "Software\\Microsoft\\Windows NT\\CurrentVersion"; const PRODUCT_NAME_VAL_STRING: &str = "ProductName"; const CURRENT_MAJOR_VERSION_NUMBER_STRING: &str = "CurrentMajorVersionNumber"; @@ -618,6 +696,78 @@ pub fn close_handler(handler: HANDLE) -> Result<()> { #[cfg(test)] mod tests { + fn test_key_base(suffix: &str) -> String { + format!( + "Software\\GuestProxyAgentUnitTests\\{}\\{}", + std::process::id(), + suffix + ) + } + + #[test] + fn get_subkey_names_test() { + let parent = test_key_base("get_subkey_names_test"); + let child1 = format!("{}\\ChildA", parent); + let child2 = format!("{}\\ChildB", parent); + + super::add_registry_key(&child1).unwrap(); + super::add_registry_key(&child2).unwrap(); + + let subkeys = super::get_subkey_names(&parent).unwrap(); + assert!(subkeys.iter().any(|name| name == "ChildA")); + assert!(subkeys.iter().any(|name| name == "ChildB")); + + super::remove_reg_key(&parent).unwrap(); + } + + #[test] + fn get_value_names_test() { + let key_name = test_key_base("get_value_names_test"); + + super::add_registry_key(&key_name).unwrap(); + super::set_reg_string(&key_name, "StringValue", "hello".to_string()).unwrap(); + super::set_registry_entry_dword(&key_name, "DwordValue", 42).unwrap(); + + let value_names = super::get_value_names(&key_name).unwrap(); + assert!(value_names.iter().any(|name| name == "StringValue")); + assert!(value_names.iter().any(|name| name == "DwordValue")); + + super::remove_reg_key(&key_name).unwrap(); + } + + #[test] + fn delete_registry_value_test() { + let key_name = test_key_base("delete_registry_value_test"); + let value_name = "ValueToDelete"; + + super::set_reg_string(&key_name, value_name, "to-delete".to_string()).unwrap(); + super::delete_registry_value(&key_name, value_name).unwrap(); + + let read_value = super::read_reg_string(&key_name, value_name, "default".to_string()); + assert_eq!("default", read_value); + + super::remove_reg_key(&key_name).unwrap(); + } + + #[test] + fn delete_registry_key_test() { + let parent = test_key_base("delete_registry_key_test"); + let child = format!("{}\\Child", parent); + let grandchild = format!("{}\\GrandChild", child); + + super::add_registry_key(&grandchild).unwrap(); + + super::delete_registry_key(&child, "GrandChild", false).unwrap(); + let subkeys = super::get_subkey_names(&child).unwrap(); + assert!(!subkeys.iter().any(|name| name == "GrandChild")); + + super::add_registry_key(&grandchild).unwrap(); + super::delete_registry_key(&parent, "Child", true).unwrap(); + let parent_subkeys = super::get_subkey_names(&parent).unwrap(); + assert!(!parent_subkeys.iter().any(|name| name == "Child")); + + super::remove_reg_key(&parent).unwrap(); + } #[test] fn get_os_version_tests() {