From d5fed1e5800670b0eb6c4e22c3bed2ee48498955 Mon Sep 17 00:00:00 2001 From: yyc12345 Date: Sun, 19 Oct 2025 14:10:21 +0800 Subject: [PATCH] feat(registry): add privilege check and improve ProgId handling - Add WFHasPrivilege function to check user privileges - Refactor ProgId structure to use standard format with optional version - Improve registry operations with safer key/value handling - Update dependencies to include Win32_System_Registry --- wfassoc/Cargo.toml | 7 ++- wfassoc/src/assoc.rs | 111 ++++++++++++------------------------ wfassoc/src/lib.rs | 29 +++++++--- wfassoc/src/winreg_extra.rs | 75 +++++++++++++++++++++--- wfassoc_dylib/src/lib.rs | 5 ++ 5 files changed, 135 insertions(+), 92 deletions(-) diff --git a/wfassoc/Cargo.toml b/wfassoc/Cargo.toml index 00a2eea..88b594d 100644 --- a/wfassoc/Cargo.toml +++ b/wfassoc/Cargo.toml @@ -8,7 +8,12 @@ license = "SPDX:MIT" [dependencies] thiserror = { workspace = true } -windows-sys = { version = "0.60.2", features = ["Win32_Security", "Win32_System_SystemServices", "Win32_UI_Shell"] } +windows-sys = { version = "0.60.2", features = [ + "Win32_Security", + "Win32_System_SystemServices", + "Win32_UI_Shell", + "Win32_System_Registry", +] } winreg = { version = "0.55.0", features = ["transactions"] } indexmap = "2.11.4" regex = "1.11.3" diff --git a/wfassoc/src/assoc.rs b/wfassoc/src/assoc.rs index 52fe283..ceb083f 100644 --- a/wfassoc/src/assoc.rs +++ b/wfassoc/src/assoc.rs @@ -79,32 +79,27 @@ impl FromStr for Ext { /// - https://learn.microsoft.com/en-us/windows/win32/shell/fa-progids /// - https://learn.microsoft.com/en-us/windows/win32/com/-progid--key pub enum ProgId { - Plain(String), - Loose(LosseProgId), - Strict(StrictProgId), + Other(String), + Std(StdProgId), } impl From<&str> for ProgId { fn from(s: &str) -> Self { - // match it for strict ProgId first - if let Ok(v) = StrictProgId::from_str(s) { - return Self::Strict(v); + // match it for standard ProgId first + if let Ok(v) = StdProgId::from_str(s) { + Self::Std(v) + } else { + // fallback with other + Self::Other(s.to_string()) } - // then match for loose ProgId - if let Ok(v) = LosseProgId::from_str(s) { - return Self::Loose(v); - } - // fallback with plain - Self::Plain(s.to_string()) } } impl Display for ProgId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ProgId::Plain(v) => v.fmt(f), - ProgId::Loose(v) => v.fmt(f), - ProgId::Strict(v) => v.fmt(f), + ProgId::Other(v) => v.fmt(f), + ProgId::Std(v) => v.fmt(f), } } } @@ -124,61 +119,18 @@ impl ParseProgIdError { } } -/// The ProgId similar with strict ProgId, but no version part. -pub struct LosseProgId { +/// The ProgId exactly follows Microsoft suggested +/// `[Vendor or Application].[Component].[Version]` format. +/// And `[Version]` part is optional. +pub struct StdProgId { vendor: String, component: String, + version: Option, } -impl LosseProgId { - pub fn new(vendor: &str, component: &str) -> Self { - Self { - vendor: vendor.to_string(), - component: component.to_string(), - } - } - - pub fn get_vendor(&self) -> &str { - &self.vendor - } - - pub fn get_component(&self) -> &str { - &self.component - } -} - -impl FromStr for LosseProgId { - type Err = ParseProgIdError; - - fn from_str(s: &str) -> Result { - static RE: LazyLock = - LazyLock::new(|| Regex::new(r"^([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)$").unwrap()); - let caps = RE.captures(s); - if let Some(caps) = caps { - let vendor = &caps[1]; - let component = &caps[2]; - Ok(Self::new(vendor, component)) - } else { - Err(ParseProgIdError::new(s)) - } - } -} - -impl Display for LosseProgId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}", self.vendor, self.component) - } -} - -/// The ProgId exactly follows `[Vendor or Application].[Component].[Version]` format. -pub struct StrictProgId { - vendor: String, - component: String, - version: u32, -} - -impl StrictProgId { - pub fn new(vendor: &str, component: &str, version: u32) -> Self { +impl StdProgId { + /// Create a new standard ProgId. + pub fn new(vendor: &str, component: &str, version: Option) -> Self { Self { vendor: vendor.to_string(), component: component.to_string(), @@ -186,32 +138,40 @@ impl StrictProgId { } } + /// Get the vendor part of standard ProgId. pub fn get_vendor(&self) -> &str { &self.vendor } + /// Get the component part of standard ProgId. pub fn get_component(&self) -> &str { &self.component } - pub fn get_version(&self) -> u32 { + /// Get the version part of standard ProgId. + pub fn get_version(&self) -> Option { self.version } } -impl FromStr for StrictProgId { +impl FromStr for StdProgId { type Err = ParseProgIdError; fn from_str(s: &str) -> Result { static RE: LazyLock = - LazyLock::new(|| Regex::new(r"^([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)\.([0-9]+)$").unwrap()); + LazyLock::new(|| Regex::new(r"^([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)(\.([0-9]+))?$").unwrap()); let caps = RE.captures(s); if let Some(caps) = caps { let vendor = &caps[1]; let component = &caps[2]; - let version = caps[3] - .parse::() - .map_err(|_| ParseProgIdError::new(s))?; + let version = match caps.get(4) { + Some(sv) => Some( + sv.as_str() + .parse::() + .map_err(|_| ParseProgIdError::new(s))?, + ), + None => None, + }; Ok(Self::new(vendor, component, version)) } else { Err(ParseProgIdError::new(s)) @@ -219,9 +179,12 @@ impl FromStr for StrictProgId { } } -impl Display for StrictProgId { +impl Display for StdProgId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}.{}", self.vendor, self.component, self.version) + match &self.version { + Some(version) => write!(f, "{}.{}.{}", self.vendor, self.component, version), + None => write!(f, "{}.{}", self.vendor, self.component), + } } } diff --git a/wfassoc/src/lib.rs b/wfassoc/src/lib.rs index 3e161bd..59dd723 100644 --- a/wfassoc/src/lib.rs +++ b/wfassoc/src/lib.rs @@ -19,6 +19,7 @@ use winreg::RegKey; use winreg::enums::{ HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, KEY_READ, KEY_WRITE, }; +use assoc::{Ext, ProgId}; // region: Error Types @@ -138,7 +139,7 @@ pub struct Program { manners: IndexSet, /// The collection holding all file extensions supported by this program. /// The key is file estension and value is its associated manner for opening it. - exts: IndexMap, + exts: IndexMap, } impl Program { @@ -200,7 +201,7 @@ impl Program { } // Create extension from string - let ext = assoc::Ext::new(ext)?; + let ext = Ext::new(ext)?; // Backup a stringfied extension for error output. let ext_str = ext.to_string(); // Insert file extension @@ -382,9 +383,17 @@ impl Program { // Open key for this extension. // If there is no such key, return directly. - let subkey = classes.open_subkey_with_flags(ext.to_string(), KEY_WRITE)?; - // Delete the default key. - subkey.delete_value("")?; + if let Some(subkey) = + winreg_extra::try_open_subkey_with_flags(&classes, ext.to_string(), KEY_WRITE)? + { + // Only delete the default key if it is equal to our ProgId + if let Some(value) = winreg_extra::try_get_value::(&subkey, "")? { + if value == prog_id.to_string() { + // Delete the default key. + subkey.delete_value("")?; + } + } + } // Okey Ok(()) @@ -416,10 +425,12 @@ impl Program { } /// Build ProgId from identifier and given file extension. - fn build_prog_id(&self, ext: &assoc::Ext) -> assoc::ProgId { - let vendor = utilities::capitalize_first_ascii(&self.identifier); - let component = utilities::capitalize_first_ascii(ext.inner()); - assoc::ProgId::Loose(assoc::LosseProgId::new(&vendor, &component)) + fn build_prog_id(&self, ext: &Ext) -> ProgId { + ProgId::Std(assoc::StdProgId::new( + &self.identifier, + &utilities::capitalize_first_ascii(ext.inner()), + None + )) } } diff --git a/wfassoc/src/winreg_extra.rs b/wfassoc/src/winreg_extra.rs index 58c2d46..95c9c6f 100644 --- a/wfassoc/src/winreg_extra.rs +++ b/wfassoc/src/winreg_extra.rs @@ -1,8 +1,69 @@ //! This module expand `winreg` crate to make it more suit for this crate. +use std::ffi::OsStr; +use std::ops::Deref; +use std::ops::DerefMut; +use windows_sys::Win32::Foundation::ERROR_FILE_NOT_FOUND; +use windows_sys::Win32::System::Registry::REG_SAM_FLAGS; +use winreg::RegKey; +use winreg::types::FromRegValue; + +// region: Extra Operations + +/// Get the subkey with given name. +/// +/// If error occurs when fetching given subkey, it return `Err(...)`, +/// otherwise, it will return `Ok(Some(...))` if aubkey is existing, +/// or `Ok(None)` if there is no suchsub key. +/// +/// Comparing with the function provided by winreg, +/// it differ "no such subkey" error and other access error. +pub fn try_open_subkey_with_flags>( + regkey: &RegKey, + path: P, + perms: REG_SAM_FLAGS, +) -> std::io::Result> { + match regkey.open_subkey_with_flags(path, perms) { + Ok(v) => Ok(Some(v)), + Err(e) => match e.raw_os_error() { + Some(errno) => match errno as u32 { + ERROR_FILE_NOT_FOUND => Ok(None), + _ => Err(e) + } + _ => Err(e), + }, + } +} + +/// Get the value by given key. +/// +/// If error occurs when fetching given key, it return `Err(...)`, +/// otherwise, it will return `Ok(Some(...))` if key is existing, +/// or `Ok(None)` if there is no such key. +/// +/// Comparing with the function provided by winreg, +/// it differ "no such key" error and other access error. +pub fn try_get_value>( + regkey: &RegKey, + name: N, +) -> std::io::Result> { + match regkey.get_value::(name) { + Ok(v) => Ok(Some(v)), + Err(e) => match e.raw_os_error() { + Some(errno) => match errno as u32 { + ERROR_FILE_NOT_FOUND => Ok(None), + _ => Err(e) + } + _ => Err(e), + }, + } +} + +// endregion + // region: Expand String -/// The struct basically is the alias of String, but make a slight difference with it, +/// The struct basically is the alias of String, but make a slight difference with it, /// to make they are different when use it with String as generic argument. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ExpandString(String); @@ -12,22 +73,22 @@ impl ExpandString { pub fn new(s: String) -> Self { Self(s) } - + /// Create from &str pub fn from_str(s: &str) -> Self { Self(s.to_string()) } - + /// Get reference to internal String. pub fn as_str(&self) -> &str { &self.0 } - + /// Get mutable reference to internal String. pub fn as_mut_str(&mut self) -> &mut String { &mut self.0 } - + /// Comsule self, return internal String. pub fn into_inner(self) -> String { self.0 @@ -35,17 +96,15 @@ impl ExpandString { } // Implement Deref trait to make it can be used like &str -use std::ops::Deref; impl Deref for ExpandString { type Target = str; - + fn deref(&self) -> &Self::Target { &self.0 } } // Implement DerefMut trait -use std::ops::DerefMut; impl DerefMut for ExpandString { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 diff --git a/wfassoc_dylib/src/lib.rs b/wfassoc_dylib/src/lib.rs index 5661133..3f57957 100644 --- a/wfassoc_dylib/src/lib.rs +++ b/wfassoc_dylib/src/lib.rs @@ -90,6 +90,11 @@ pub extern "C" fn WFGetLastError() -> *const c_char { get_last_error() } +#[unsafe(no_mangle)] +pub extern "C" fn WFHasPrivilege() -> bool { + wfassoc::utilities::has_privilege() +} + #[unsafe(no_mangle)] pub extern "C" fn WFAdd(left: u32, right: u32, rv: *mut u32) -> bool { unsafe { *rv = left + right; }