diff --git a/Cargo.lock b/Cargo.lock index 45316f7..4431f6a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -544,9 +544,9 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-link" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-sys" diff --git a/wfassoc/Cargo.toml b/wfassoc/Cargo.toml index 7e0124f..20b926b 100644 --- a/wfassoc/Cargo.toml +++ b/wfassoc/Cargo.toml @@ -8,7 +8,7 @@ license = "SPDX:MIT" [dependencies] thiserror = { workspace = true } -windows-sys = { version = "0.60.2", features = ["Win32_Security", "Win32_System_SystemServices"]} -winreg = "0.55.0" +windows-sys = { version = "0.60.2", features = ["Win32_Security", "Win32_System_SystemServices"] } +winreg = { version = "0.55.0", features = ["transactions"] } regex = "1.11.3" uuid = "1.18.1" diff --git a/wfassoc/src/lib.rs b/wfassoc/src/lib.rs index ab44320..da276a4 100644 --- a/wfassoc/src/lib.rs +++ b/wfassoc/src/lib.rs @@ -4,36 +4,121 @@ #[cfg(not(target_os = "windows"))] compile_error!("Crate wfassoc is only supported on Windows."); -/// The expand of winreg crate according to our module requirements. -mod winregex; - -use regex::Regex; -use std::fmt::Display; -use std::str::FromStr; -use std::sync::LazyLock; +use std::ffi::OsStr; +use std::path::{Path, PathBuf}; use thiserror::Error as TeError; use winreg::RegKey; +use winreg::enums::{ + HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, KEY_READ, KEY_WRITE, +}; +use winreg::transaction::Transaction; // region: Error Types /// All possible error occurs in this crate. #[derive(Debug, TeError)] -pub enum Error { - #[error( - "can not register because lack essential privilege. please consider running with Administrator role" - )] +pub enum WfError { + #[error("no administrative privilege")] NoPrivilege, - #[error("{0}")] - Register(#[from] std::io::Error), - #[error("{0}")] - BadFileExt(#[from] ParseFileExtError), - #[error("{0}")] - BadProgId(#[from] ParseProgIdError), + #[error("error occurs when manipulating with Registry: {0}")] + BadRegOper(#[from] std::io::Error), + #[error("given full path to application is invalid")] + BadFullAppPath, + #[error("failed when casting path or OS string into string")] + BadOsStrCast, +} + +/// The result type used in this crate. +pub type WfResult = Result; + +// endregion + +// region: Scope and View + +/// The scope where wfassoc will register and unregister application. +#[derive(Debug, Copy, Clone)] +pub enum Scope { + /// Scope for current user. + User, + /// Scope for all users under this computer. + System, +} + +/// The error occurs when cast View into Scope. +#[derive(Debug, TeError)] +#[error("hybrid view can not be cast into any scope")] +pub struct TryFromViewError {} + +impl TryFromViewError { + fn new() -> Self { + Self {} + } +} + +impl TryFrom for Scope { + type Error = TryFromViewError; + + fn try_from(value: View) -> Result { + match value { + View::User => Ok(Self::User), + View::System => Ok(Self::System), + View::Hybrid => Err(TryFromViewError::new()), + } + } +} + +impl Scope { + /// Check whether we have enough privilege when operating in current scope. + /// If we have, return true, otherwise false. + fn has_privilege(&self) -> bool { + // If we operate on System, and we do not has privilege, + // we think we do not have privilege, otherwise, + // there is no privilege required. + !matches!(self, Self::System if !has_privilege()) + } +} + +/// The view when wfassoc querying file extension association. +#[derive(Debug, Copy, Clone)] +pub enum View { + /// The view of current user. + User, + /// The view of system. + System, + /// Hybrid view of User and System. + /// It can be seen as that we use System first and then use User to override any existing items. + Hybrid, +} + +impl From for View { + fn from(value: Scope) -> Self { + match value { + Scope::User => Self::User, + Scope::System => Self::System, + } + } } // endregion -// region: Privilege, Scope and View +// region: Utilities + +/// The println macro only works on Debug mode +/// for tracing the execution of some important functions. +macro_rules! debug_println { + // For no argument. + () => { + if cfg!(debug_assertions) { + println!(); + } + }; + // For one or more arguments like println!. + ($($arg:tt)*) => { + if cfg!(debug_assertions) { + println!($($arg)*); + } + }; +} /// Check whether current process has administrative privilege. /// @@ -41,7 +126,7 @@ pub enum Error { /// Return true if it is, otherwise false. /// /// Reference: https://learn.microsoft.com/en-us/windows/win32/api/securitybaseapi/nf-securitybaseapi-checktokenmembership -pub fn has_privilege() -> bool { +fn has_privilege() -> bool { use windows_sys::Win32::Foundation::HANDLE; use windows_sys::Win32::Security::{ AllocateAndInitializeSid, CheckTokenMembership, FreeSid, PSID, SECURITY_NT_AUTHORITY, @@ -88,500 +173,142 @@ pub fn has_privilege() -> bool { is_member != 0 } -/// The scope where wfassoc will register and unregister. -#[derive(Debug, Copy, Clone)] -pub enum Scope { - /// Scope for current user. - User, - /// Scope for all users under this computer. - System, +/// Try casting given &Path into &str. +fn path_to_str(path: &Path) -> WfResult<&str> { + path.to_str().ok_or(WfError::BadOsStrCast) } -/// The view when wfassoc querying infomations. -#[derive(Debug, Copy, Clone)] -pub enum View { - /// The view of current user. - User, - /// The view of system. - System, - /// Hybrid view of User and System. - /// It can be seen as that we use System first and then use User to override any existing items. - Hybrid, -} - -/// The error occurs when cast View into Scope. -#[derive(Debug, TeError)] -#[error("hybrid view can not be cast into any scope")] -pub struct TryFromViewError {} - -impl TryFromViewError { - fn new() -> Self { - Self {} - } -} - -impl From for View { - fn from(value: Scope) -> Self { - match value { - Scope::User => Self::User, - Scope::System => Self::System, - } - } -} - -impl TryFrom for Scope { - type Error = TryFromViewError; - - fn try_from(value: View) -> Result { - match value { - View::User => Ok(Self::User), - View::System => Ok(Self::System), - View::Hybrid => Err(TryFromViewError::new()), - } - } -} - -impl Scope { - /// Check whether we have enough privilege when operating in current scope. - /// If we have, simply return, otherwise return error. - fn check_privilege(&self) -> Result<(), Error> { - if matches!(self, Self::System if !has_privilege()) { - Err(Error::NoPrivilege) - } else { - Ok(()) - } - } +/// Try casting given &OsStr into &str. +fn osstr_to_str(osstr: &OsStr) -> WfResult<&str> { + osstr.to_str().ok_or(WfError::BadOsStrCast) } // endregion -// region: File Extension +// region: Registrar -/// The struct representing an file extension which must start with dot (`.`) -/// and followed by at least one arbitrary characters. -#[derive(Debug, Clone)] -pub struct FileExt { - /// The body of file extension (excluding dot). - inner: String, +/// The core registrar for register and unregister application. +pub struct Registrar { + /// The fully qualified path to the application. + full_path: PathBuf, } -impl FileExt { - pub fn new(file_ext: &str) -> Result { - Self::from_str(file_ext) - } -} - -/// The error occurs when try parsing string into FileExt. -#[derive(Debug, TeError)] -#[error("given file extension is invalid")] -pub struct ParseFileExtError {} - -impl ParseFileExtError { - fn new() -> Self { - Self {} - } -} - -impl Display for FileExt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, ".{}", self.inner) - } -} - -impl FromStr for FileExt { - type Err = ParseFileExtError; - - fn from_str(s: &str) -> Result { - static RE: LazyLock = LazyLock::new(|| Regex::new(r"^\.([^\.]+)$").unwrap()); - match RE.captures(s) { - Some(v) => Ok(Self { - inner: v[1].to_string(), - }), - None => Err(ParseFileExtError::new()), - } - } -} - -impl FileExt { - fn open_scope(&self, scope: Scope) -> Result { - use winreg::enums::{HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, KEY_READ, KEY_WRITE}; - - // check privilege - scope.check_privilege()?; - // get the root key - let hk = match scope { - Scope::User => RegKey::predef(HKEY_CURRENT_USER), - Scope::System => RegKey::predef(HKEY_LOCAL_MACHINE), - }; - // navigate to classes - let classes = hk.open_subkey_with_flags("Software\\Classes", KEY_READ | KEY_WRITE)?; - // okey - Ok(classes) - } - - fn open_view(&self, view: View) -> Result, Error> { - use winreg::enums::{HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, KEY_READ}; - - // navigate to extension container - let hk = match view { - View::User => RegKey::predef(HKEY_CURRENT_USER), - View::System => RegKey::predef(HKEY_LOCAL_MACHINE), - View::Hybrid => RegKey::predef(HKEY_CLASSES_ROOT), - }; - let classes = match view { - View::User | View::System => { - hk.open_subkey_with_flags("Software\\Classes", KEY_READ)? - } - View::Hybrid => hk.open_subkey_with_flags("", KEY_READ)?, - }; - // check whether there is this ext - classes. - // open extension key if possible - let thisext = classes.open_subkey_with_flags(file_ext.to_string(), KEY_READ)?; - // okey - Ok(classes) - } - - pub fn get_current(&self, view: View) -> Option { - todo!() - } - - pub fn set_current(&mut self, scope: Scope, prog_id: Option<&ProgId>) -> Result<(), Error> { - scope.check_privilege()?; - todo!() - } - - pub fn iter_open_with(&self, view: View) -> Result, Error> { - let viewer = match self.open_view(view)? { - Some(viewer) => viewer, - None => return Ok(std::iter::empty::()), - }; - let it = winregex::iter_sz_keys(&viewer); - let it = winregex::exclude_default_key(it); - - Ok(it.map(|s| ProgId::from(s.as_str()))) - } - - pub fn insert_open_with(&mut self, scope: Scope, prog_id: &ProgId) -> Result<(), Error> { - scope.check_privilege()?; - todo!() - } - - pub fn flash_open_with( - &mut self, - scope: Scope, - prog_ids: impl Iterator, - ) -> Result<(), Error> { - scope.check_privilege()?; - todo!() - } -} - -/// The association infomations of specific file extension. -#[derive(Debug)] -pub struct FileExtAssoc { - default: String, - open_with_progids: Vec, -} - -impl FileExtAssoc { - fn new(file_ext: &FileExt, view: View) -> Option { - use winreg::RegKey; - use winreg::enums::{HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, KEY_READ}; - - // navigate to extension container - let hk = match view { - View::User => RegKey::predef(HKEY_CURRENT_USER), - View::System => RegKey::predef(HKEY_LOCAL_MACHINE), - View::Hybrid => RegKey::predef(HKEY_CLASSES_ROOT), - }; - let classes = match view { - View::User | View::System => hk - .open_subkey_with_flags("Software\\Classes", KEY_READ) - .unwrap(), - View::Hybrid => hk.open_subkey_with_flags("", KEY_READ).unwrap(), - }; - - // open extension key if possible - let thisext = match classes.open_subkey_with_flags(file_ext.to_string(), KEY_READ) { - Ok(v) => v, - Err(_) => return None, - }; - - // fetch extension infos. - let default = thisext.get_value("").unwrap_or(String::new()); - let open_with_progids = - if let Ok(progids) = thisext.open_subkey_with_flags("OpenWithProdIds", KEY_READ) { - progids - .enum_keys() - .map(|x| x.unwrap()) - .filter(|k| !k.is_empty()) - .collect() - } else { - Vec::new() - }; - - Some(Self { - default, - open_with_progids, - }) - } - - pub fn get_default(&self) -> &str { - &self.default - } - - pub fn len_open_with_progid(&self) -> usize { - self.open_with_progids.len() - } - - pub fn iter_open_with_progids(&self) -> impl Iterator { - self.open_with_progids.iter().map(|s| s.as_str()) - } -} - -// endregion - -// region: Executable Resource - -// /// The struct representing an Windows executable resources path like -// /// `path_to_file.exe,1`. -// pub struct ExecRc { -// /// The path to binary for finding resources. -// binary: PathBuf, -// /// The inner index of resources. -// index: u32, -// } - -// impl ExecRc { -// pub fn new(res_str: &str) -> Result { -// static RE: LazyLock = LazyLock::new(|| Regex::new(r"^([^,]+),([0-9]+)$").unwrap()); -// let caps = RE.captures(res_str); -// if let Some(caps) = caps { -// let binary = PathBuf::from_str(&caps[1])?; -// let index = caps[2].parse::()?; -// Ok(Self { binary, index }) -// } else { -// Err(ParseExecRcError::NoCapture) -// } -// } -// } - -// /// The error occurs when try parsing string into ExecRc. -// #[derive(Debug, TeError)] -// #[error("given string is not a valid executable resource string")] -// pub enum ParseExecRcError { -// /// Given string is not matched with format. -// NoCapture, -// /// Fail to convert executable part into path. -// BadBinaryPath(#[from] std::convert::Infallible), -// /// Fail to convert index part into valid number. -// BadIndex(#[from] std::num::ParseIntError), -// } - -// impl FromStr for ExecRc { -// type Err = ParseExecRcError; - -// fn from_str(s: &str) -> Result { -// ExecRc::new(s) -// } -// } - -// impl Display for ExecRc { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// write!(f, "{},{}", self.binary.to_str().unwrap(), self.index) -// } -// } - -// endregion - -// region: Programmatic Identifiers (ProgId) - -/// The struct representing Programmatic Identifiers (ProgId). -/// -/// Because there is optional part in standard ProgId, and not all software developers -/// are willing to following Microsoft suggestions, there is no strict constaint for ProgId. -/// So this struct is actually an enum which holding any possible ProgId format. -/// -/// Reference: https://learn.microsoft.com/en-us/windows/win32/shell/fa-progids -pub enum ProgId { - Plain(String), - Loose(LosseProgId), - Strict(StrictProgId), -} - -impl From<&str> for ProgId { - fn from(s: &str) -> Self { - // match it for strict ProgId first - if let Ok(v) = StrictProgId::from_str(s) { - return Self::Strict(v); - } - // then match for loose ProgId - if let Ok(v) = LosseProgId::from_str(s) { - return Self::Loose(v); - } - // fallback with plain - Self::Plain(s.to_string()) - } -} - -impl Display for ProgId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ProgId::Plain(v) => v.fmt(f), - ProgId::Loose(v) => v.fmt(f), - ProgId::Strict(v) => v.fmt(f), - } - } -} - -/// The error occurs when parsing ProgId. -#[derive(Debug, TeError)] -#[error("given ProgId string is invalid")] -pub struct ParseProgIdError {} - -impl ParseProgIdError { - fn new() -> Self { - Self {} - } -} - -/// The ProgId similar with strict ProgId, but no version part. -pub struct LosseProgId { - vendor: String, - component: String, -} - -impl LosseProgId { - pub fn new(vendor: &str, component: &str) -> Self { +impl Registrar { + /// Create a new registrar for following operations. + pub fn new(full_path: &Path) -> Self { Self { - vendor: vendor.to_string(), - component: component.to_string(), - } - } - - pub fn get_vendor(&self) -> &str { - &self.vendor - } - - pub fn get_component(&self) -> &str { - &self.component - } -} - -impl FromStr for LosseProgId { - type Err = ParseProgIdError; - - fn from_str(s: &str) -> Result { - static RE: LazyLock = - LazyLock::new(|| Regex::new(r"^([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)$").unwrap()); - let caps = RE.captures(s); - if let Some(caps) = caps { - let vendor = &caps[1]; - let component = &caps[2]; - Ok(Self::new(vendor, component)) - } else { - Err(ParseProgIdError::new()) + full_path: full_path.to_path_buf(), } } } -impl Display for LosseProgId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}", self.vendor, self.component) +impl Registrar { + const APP_PATHS: &str = "Software\\Microsoft\\Windows\\CurrentVersion\\App Paths"; + const APPLICATIONS: &str = "Software\\Classes\\Applications"; + + /// Register this application. + pub fn register(&self, scope: Scope) -> WfResult<()> { + // Fetch root key. + let hk = RegKey::predef(match scope { + Scope::User => HKEY_CURRENT_USER, + Scope::System => HKEY_LOCAL_MACHINE, + }); + // Fetch file name and start in path. + let file_name = self.extract_file_name()?; + let start_in = self.extract_start_in()?; + + // Create App Paths subkey + debug_println!("Adding App Paths subkey..."); + let subkey_parent = hk.open_subkey_with_flags(Self::APP_PATHS, KEY_READ)?; + let (subkey, _) = subkey_parent.create_subkey_with_flags(file_name, KEY_WRITE)?; + // Write App Paths values + subkey.set_value("", &path_to_str(&self.full_path)?)?; + subkey.set_value("Path", &osstr_to_str(&start_in)?)?; + + // Create Applications subkey + debug_println!("Adding Applications subkey..."); + let subkey_parent = hk.open_subkey_with_flags(Self::APPLICATIONS, KEY_READ)?; + let (subkey, _) = subkey_parent.create_subkey_with_flags(file_name, KEY_WRITE)?; + // Write Applications values + subkey.set_value("FriendlyAppName", &"WoW!")?; + + Ok(()) } -} -/// The ProgId exactly follows `[Vendor or Application].[Component].[Version]` format. -pub struct StrictProgId { - vendor: String, - component: String, - version: u32, -} + /// Unregister this application. + pub fn unregister(&self, scope: Scope) -> WfResult<()> { + // Fetch root key and file name. + let hk = RegKey::predef(match scope { + Scope::User => HKEY_CURRENT_USER, + Scope::System => HKEY_LOCAL_MACHINE, + }); + let file_name = self.extract_file_name()?; -impl StrictProgId { - pub fn new(vendor: &str, component: &str, version: u32) -> Self { - Self { - vendor: vendor.to_string(), - component: component.to_string(), - version, + // Remove App Paths subkey + debug_println!("Removing App Paths subkey..."); + let subkey_parent = hk.open_subkey_with_flags(Self::APP_PATHS, KEY_WRITE)?; + subkey_parent.delete_subkey_all(file_name)?; + + // Remove Applications subkey + debug_println!("Removing Applications subkey..."); + let subkey_parent = hk.open_subkey_with_flags(Self::APPLICATIONS, KEY_READ)?; + subkey_parent.delete_subkey_all(file_name)?; + + // Okey + Ok(()) + } + + /// Check whether this application has been registered. + /// + /// Please note that this is a rough check and do not validate any data. + pub fn is_registered(&self, scope: Scope) -> WfResult { + // Fetch root key and file name. + let hk = RegKey::predef(match scope { + Scope::User => HKEY_CURRENT_USER, + Scope::System => HKEY_LOCAL_MACHINE, + }); + let file_name = self.extract_file_name()?; + + // Check App Paths subkey. + debug_println!("Checking App Paths subkey..."); + let subkey_parent = hk.open_subkey_with_flags(Self::APP_PATHS, KEY_READ)?; + if let Err(_) = subkey_parent.open_subkey_with_flags(file_name, KEY_READ) { + return Ok(false); } - } - pub fn get_vendor(&self) -> &str { - &self.vendor - } - - pub fn get_component(&self) -> &str { - &self.component - } - - pub fn get_version(&self) -> u32 { - self.version - } -} - -impl FromStr for StrictProgId { - type Err = ParseProgIdError; - - fn from_str(s: &str) -> Result { - static RE: LazyLock = - LazyLock::new(|| Regex::new(r"^([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)\.([0-9]+)$").unwrap()); - let caps = RE.captures(s); - if let Some(caps) = caps { - let vendor = &caps[1]; - let component = &caps[2]; - let version = caps[3] - .parse::() - .map_err(|_| ParseProgIdError::new())?; - Ok(Self::new(vendor, component, version)) - } else { - Err(ParseProgIdError::new()) + // Check Application subkey. + debug_println!("Checking Applications subkey..."); + let subkey_parent = hk.open_subkey_with_flags(Self::APPLICATIONS, KEY_READ)?; + if let Err(_) = subkey_parent.open_subkey_with_flags(file_name, KEY_READ) { + return Ok(false); } + + // Both subkeys are roughly existing. + Ok(true) } } -impl Display for StrictProgId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}.{}", self.vendor, self.component, self.version) +impl Registrar { + /// Extract the file name part from full path to application, + /// which was used in Registry path component. + fn extract_file_name(&self) -> WfResult<&OsStr> { + // Get the file name part and make sure it is not empty + self.full_path + .file_name() + .and_then(|p| if p.is_empty() { None } else { Some(p) }) + .ok_or(WfError::BadFullAppPath) + } + + /// Extract the start in path from full path to application, + /// which basically is the stem of full path. + fn extract_start_in(&self) -> WfResult<&OsStr> { + // Get parent part and make sure it is not empty + self.full_path + .parent() + .map(|p| p.as_os_str()) + .and_then(|p| if p.is_empty() { None } else { Some(p) }) + .ok_or(WfError::BadFullAppPath) } } // endregion - -// region: Program - -// /// The struct representing a complete Win32 program. -// pub struct Program { -// file_exts: Vec, -// } - -// impl Program { -// /// Create a program descriptor. -// pub fn new() -> Self { -// Self { -// file_exts: Vec::new(), -// } -// } -// } - -// impl Program { -// /// Register program in this computer -// pub fn register(&self, kind: RegisterKind) -> Result<(), Error> { -// todo!("pretend to register >_<...") -// } - -// /// Unregister program from this computer. -// pub fn unregister(&self) -> Result<(), Error> { -// todo!("pretend to unregister >_<...") -// } -// } - -// impl Program { -// /// Query file extension infos which this program want to associate with. -// pub fn query(&self) -> Result<(), Error> { -// todo!("pretend to query >_<...") -// } -// } - -// endregion diff --git a/wfassoc/src/program.rs b/wfassoc/src/program.rs deleted file mode 100644 index 9253c27..0000000 --- a/wfassoc/src/program.rs +++ /dev/null @@ -1,2 +0,0 @@ -use super::error::{Error, Result}; -use super::components::*; diff --git a/wfassoc/src/shit.rs b/wfassoc/src/shit.rs new file mode 100644 index 0000000..e81c03f --- /dev/null +++ b/wfassoc/src/shit.rs @@ -0,0 +1,582 @@ + +/// The expand of winreg crate according to our module requirements. +mod winregex; + +use regex::Regex; +use std::fmt::Display; +use std::str::FromStr; +use std::sync::LazyLock; +use thiserror::Error as TeError; +use winreg::RegKey; + +// region: Error Types + +/// All possible error occurs in this crate. +#[derive(Debug, TeError)] +pub enum Error { + #[error( + "can not register because lack essential privilege. please consider running with Administrator role" + )] + NoPrivilege, + #[error("{0}")] + Register(#[from] std::io::Error), + #[error("{0}")] + BadFileExt(#[from] ParseFileExtError), + #[error("{0}")] + BadProgId(#[from] ParseProgIdError), +} + +// endregion + +// region: Privilege, Scope and View + +/// Check whether current process has administrative privilege. +/// +/// It usually means that checking whether current process is running as Administrator. +/// Return true if it is, otherwise false. +/// +/// Reference: https://learn.microsoft.com/en-us/windows/win32/api/securitybaseapi/nf-securitybaseapi-checktokenmembership +pub fn has_privilege() -> bool { + use windows_sys::Win32::Foundation::HANDLE; + use windows_sys::Win32::Security::{ + AllocateAndInitializeSid, CheckTokenMembership, FreeSid, PSID, SECURITY_NT_AUTHORITY, + }; + use windows_sys::Win32::System::SystemServices::{ + DOMAIN_ALIAS_RID_ADMINS, SECURITY_BUILTIN_DOMAIN_RID, + }; + use windows_sys::core::BOOL; + + let nt_authority = SECURITY_NT_AUTHORITY.clone(); + let mut administrators_group: PSID = PSID::default(); + let success: BOOL = unsafe { + AllocateAndInitializeSid( + &nt_authority, + 2, + SECURITY_BUILTIN_DOMAIN_RID as u32, + DOMAIN_ALIAS_RID_ADMINS as u32, + 0, + 0, + 0, + 0, + 0, + 0, + &mut administrators_group, + ) + }; + + if success == 0 { + panic!("Win32 AllocateAndInitializeSid() failed"); + } + + let mut is_member: BOOL = BOOL::default(); + let success: BOOL = + unsafe { CheckTokenMembership(HANDLE::default(), administrators_group, &mut is_member) }; + + unsafe { + FreeSid(administrators_group); + } + + if success == 0 { + panic!("Win32 CheckTokenMembership() failed"); + } + + is_member != 0 +} + +/// The scope where wfassoc will register and unregister. +#[derive(Debug, Copy, Clone)] +pub enum Scope { + /// Scope for current user. + User, + /// Scope for all users under this computer. + System, +} + +/// The view when wfassoc querying infomations. +#[derive(Debug, Copy, Clone)] +pub enum View { + /// The view of current user. + User, + /// The view of system. + System, + /// Hybrid view of User and System. + /// It can be seen as that we use System first and then use User to override any existing items. + Hybrid, +} + +/// The error occurs when cast View into Scope. +#[derive(Debug, TeError)] +#[error("hybrid view can not be cast into any scope")] +pub struct TryFromViewError {} + +impl TryFromViewError { + fn new() -> Self { + Self {} + } +} + +impl From for View { + fn from(value: Scope) -> Self { + match value { + Scope::User => Self::User, + Scope::System => Self::System, + } + } +} + +impl TryFrom for Scope { + type Error = TryFromViewError; + + fn try_from(value: View) -> Result { + match value { + View::User => Ok(Self::User), + View::System => Ok(Self::System), + View::Hybrid => Err(TryFromViewError::new()), + } + } +} + +impl Scope { + /// Check whether we have enough privilege when operating in current scope. + /// If we have, simply return, otherwise return error. + fn check_privilege(&self) -> Result<(), Error> { + if matches!(self, Self::System if !has_privilege()) { + Err(Error::NoPrivilege) + } else { + Ok(()) + } + } +} + +// endregion + +// region: File Extension + +/// The struct representing an file extension which must start with dot (`.`) +/// and followed by at least one arbitrary characters. +#[derive(Debug, Clone)] +pub struct FileExt { + /// The body of file extension (excluding dot). + inner: String, +} + +impl FileExt { + pub fn new(file_ext: &str) -> Result { + Self::from_str(file_ext) + } +} + +/// The error occurs when try parsing string into FileExt. +#[derive(Debug, TeError)] +#[error("given file extension is invalid")] +pub struct ParseFileExtError {} + +impl ParseFileExtError { + fn new() -> Self { + Self {} + } +} + +impl Display for FileExt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, ".{}", self.inner) + } +} + +impl FromStr for FileExt { + type Err = ParseFileExtError; + + fn from_str(s: &str) -> Result { + static RE: LazyLock = LazyLock::new(|| Regex::new(r"^\.([^\.]+)$").unwrap()); + match RE.captures(s) { + Some(v) => Ok(Self { + inner: v[1].to_string(), + }), + None => Err(ParseFileExtError::new()), + } + } +} + +impl FileExt { + fn open_scope(&self, scope: Scope) -> Result { + use winreg::enums::{HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, KEY_READ, KEY_WRITE}; + + // check privilege + scope.check_privilege()?; + // get the root key + let hk = match scope { + Scope::User => RegKey::predef(HKEY_CURRENT_USER), + Scope::System => RegKey::predef(HKEY_LOCAL_MACHINE), + }; + // navigate to classes + let classes = hk.open_subkey_with_flags("Software\\Classes", KEY_READ | KEY_WRITE)?; + // okey + Ok(classes) + } + + fn open_view(&self, view: View) -> Result, Error> { + use winreg::enums::{HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, KEY_READ}; + + // navigate to extension container + let hk = match view { + View::User => RegKey::predef(HKEY_CURRENT_USER), + View::System => RegKey::predef(HKEY_LOCAL_MACHINE), + View::Hybrid => RegKey::predef(HKEY_CLASSES_ROOT), + }; + let classes = match view { + View::User | View::System => { + hk.open_subkey_with_flags("Software\\Classes", KEY_READ)? + } + View::Hybrid => hk.open_subkey_with_flags("", KEY_READ)?, + }; + // check whether there is this ext + classes. + // open extension key if possible + let thisext = classes.open_subkey_with_flags(file_ext.to_string(), KEY_READ)?; + // okey + Ok(classes) + } + + pub fn get_current(&self, view: View) -> Option { + todo!() + } + + pub fn set_current(&mut self, scope: Scope, prog_id: Option<&ProgId>) -> Result<(), Error> { + scope.check_privilege()?; + todo!() + } + + pub fn iter_open_with(&self, view: View) -> Result, Error> { + let viewer = match self.open_view(view)? { + Some(viewer) => viewer, + None => return Ok(std::iter::empty::()), + }; + let it = winregex::iter_sz_keys(&viewer); + let it = winregex::exclude_default_key(it); + + Ok(it.map(|s| ProgId::from(s.as_str()))) + } + + pub fn insert_open_with(&mut self, scope: Scope, prog_id: &ProgId) -> Result<(), Error> { + scope.check_privilege()?; + todo!() + } + + pub fn flash_open_with( + &mut self, + scope: Scope, + prog_ids: impl Iterator, + ) -> Result<(), Error> { + scope.check_privilege()?; + todo!() + } +} + +/// The association infomations of specific file extension. +#[derive(Debug)] +pub struct FileExtAssoc { + default: String, + open_with_progids: Vec, +} + +impl FileExtAssoc { + fn new(file_ext: &FileExt, view: View) -> Option { + use winreg::RegKey; + use winreg::enums::{HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE, KEY_READ}; + + // navigate to extension container + let hk = match view { + View::User => RegKey::predef(HKEY_CURRENT_USER), + View::System => RegKey::predef(HKEY_LOCAL_MACHINE), + View::Hybrid => RegKey::predef(HKEY_CLASSES_ROOT), + }; + let classes = match view { + View::User | View::System => hk + .open_subkey_with_flags("Software\\Classes", KEY_READ) + .unwrap(), + View::Hybrid => hk.open_subkey_with_flags("", KEY_READ).unwrap(), + }; + + // open extension key if possible + let thisext = match classes.open_subkey_with_flags(file_ext.to_string(), KEY_READ) { + Ok(v) => v, + Err(_) => return None, + }; + + // fetch extension infos. + let default = thisext.get_value("").unwrap_or(String::new()); + let open_with_progids = + if let Ok(progids) = thisext.open_subkey_with_flags("OpenWithProdIds", KEY_READ) { + progids + .enum_keys() + .map(|x| x.unwrap()) + .filter(|k| !k.is_empty()) + .collect() + } else { + Vec::new() + }; + + Some(Self { + default, + open_with_progids, + }) + } + + pub fn get_default(&self) -> &str { + &self.default + } + + pub fn len_open_with_progid(&self) -> usize { + self.open_with_progids.len() + } + + pub fn iter_open_with_progids(&self) -> impl Iterator { + self.open_with_progids.iter().map(|s| s.as_str()) + } +} + +// endregion + +// region: Executable Resource + +// /// The struct representing an Windows executable resources path like +// /// `path_to_file.exe,1`. +// pub struct ExecRc { +// /// The path to binary for finding resources. +// binary: PathBuf, +// /// The inner index of resources. +// index: u32, +// } + +// impl ExecRc { +// pub fn new(res_str: &str) -> Result { +// static RE: LazyLock = LazyLock::new(|| Regex::new(r"^([^,]+),([0-9]+)$").unwrap()); +// let caps = RE.captures(res_str); +// if let Some(caps) = caps { +// let binary = PathBuf::from_str(&caps[1])?; +// let index = caps[2].parse::()?; +// Ok(Self { binary, index }) +// } else { +// Err(ParseExecRcError::NoCapture) +// } +// } +// } + +// /// The error occurs when try parsing string into ExecRc. +// #[derive(Debug, TeError)] +// #[error("given string is not a valid executable resource string")] +// pub enum ParseExecRcError { +// /// Given string is not matched with format. +// NoCapture, +// /// Fail to convert executable part into path. +// BadBinaryPath(#[from] std::convert::Infallible), +// /// Fail to convert index part into valid number. +// BadIndex(#[from] std::num::ParseIntError), +// } + +// impl FromStr for ExecRc { +// type Err = ParseExecRcError; + +// fn from_str(s: &str) -> Result { +// ExecRc::new(s) +// } +// } + +// impl Display for ExecRc { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// write!(f, "{},{}", self.binary.to_str().unwrap(), self.index) +// } +// } + +// endregion + +// region: Programmatic Identifiers (ProgId) + +/// The struct representing Programmatic Identifiers (ProgId). +/// +/// Because there is optional part in standard ProgId, and not all software developers +/// are willing to following Microsoft suggestions, there is no strict constaint for ProgId. +/// So this struct is actually an enum which holding any possible ProgId format. +/// +/// Reference: https://learn.microsoft.com/en-us/windows/win32/shell/fa-progids +pub enum ProgId { + Plain(String), + Loose(LosseProgId), + Strict(StrictProgId), +} + +impl From<&str> for ProgId { + fn from(s: &str) -> Self { + // match it for strict ProgId first + if let Ok(v) = StrictProgId::from_str(s) { + return Self::Strict(v); + } + // then match for loose ProgId + if let Ok(v) = LosseProgId::from_str(s) { + return Self::Loose(v); + } + // fallback with plain + Self::Plain(s.to_string()) + } +} + +impl Display for ProgId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ProgId::Plain(v) => v.fmt(f), + ProgId::Loose(v) => v.fmt(f), + ProgId::Strict(v) => v.fmt(f), + } + } +} + +/// The error occurs when parsing ProgId. +#[derive(Debug, TeError)] +#[error("given ProgId string is invalid")] +pub struct ParseProgIdError {} + +impl ParseProgIdError { + fn new() -> Self { + Self {} + } +} + +/// The ProgId similar with strict ProgId, but no version part. +pub struct LosseProgId { + vendor: String, + component: String, +} + +impl LosseProgId { + pub fn new(vendor: &str, component: &str) -> Self { + Self { + vendor: vendor.to_string(), + component: component.to_string(), + } + } + + pub fn get_vendor(&self) -> &str { + &self.vendor + } + + pub fn get_component(&self) -> &str { + &self.component + } +} + +impl FromStr for LosseProgId { + type Err = ParseProgIdError; + + fn from_str(s: &str) -> Result { + static RE: LazyLock = + LazyLock::new(|| Regex::new(r"^([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)$").unwrap()); + let caps = RE.captures(s); + if let Some(caps) = caps { + let vendor = &caps[1]; + let component = &caps[2]; + Ok(Self::new(vendor, component)) + } else { + Err(ParseProgIdError::new()) + } + } +} + +impl Display for LosseProgId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}", self.vendor, self.component) + } +} + +/// The ProgId exactly follows `[Vendor or Application].[Component].[Version]` format. +pub struct StrictProgId { + vendor: String, + component: String, + version: u32, +} + +impl StrictProgId { + pub fn new(vendor: &str, component: &str, version: u32) -> Self { + Self { + vendor: vendor.to_string(), + component: component.to_string(), + version, + } + } + + pub fn get_vendor(&self) -> &str { + &self.vendor + } + + pub fn get_component(&self) -> &str { + &self.component + } + + pub fn get_version(&self) -> u32 { + self.version + } +} + +impl FromStr for StrictProgId { + type Err = ParseProgIdError; + + fn from_str(s: &str) -> Result { + static RE: LazyLock = + LazyLock::new(|| Regex::new(r"^([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)\.([0-9]+)$").unwrap()); + let caps = RE.captures(s); + if let Some(caps) = caps { + let vendor = &caps[1]; + let component = &caps[2]; + let version = caps[3] + .parse::() + .map_err(|_| ParseProgIdError::new())?; + Ok(Self::new(vendor, component, version)) + } else { + Err(ParseProgIdError::new()) + } + } +} + +impl Display for StrictProgId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.vendor, self.component, self.version) + } +} + +// endregion + +// region: Program + +// /// The struct representing a complete Win32 program. +// pub struct Program { +// file_exts: Vec, +// } + +// impl Program { +// /// Create a program descriptor. +// pub fn new() -> Self { +// Self { +// file_exts: Vec::new(), +// } +// } +// } + +// impl Program { +// /// Register program in this computer +// pub fn register(&self, kind: RegisterKind) -> Result<(), Error> { +// todo!("pretend to register >_<...") +// } + +// /// Unregister program from this computer. +// pub fn unregister(&self) -> Result<(), Error> { +// todo!("pretend to unregister >_<...") +// } +// } + +// impl Program { +// /// Query file extension infos which this program want to associate with. +// pub fn query(&self) -> Result<(), Error> { +// todo!("pretend to query >_<...") +// } +// } + +// endregion diff --git a/wfassoc/src/winregex.rs b/wfassoc/src/winregex.rs deleted file mode 100644 index 02c2faf..0000000 --- a/wfassoc/src/winregex.rs +++ /dev/null @@ -1,96 +0,0 @@ -use winreg::RegKey; -use winreg::EnumKeys; -use winreg::enums::RegType; -// use thiserror::Error as TeError; - -// #[derive(Debug, TeError)] -// pub enum WinRegExError { -// #[error("{0}")] -// Io(#[from] std::io::Error) -// } - -// pub type WinRegExResult = Result; - -// region Iterate Keys without Error - -pub struct IterKeys<'a> { - iter_keys: EnumKeys<'a>, -} - -impl<'a> IterKeys<'a> { - fn new(regkey: &'a RegKey) -> Self { - Self { iter_keys: regkey.enum_keys() } - } -} - -impl<'a> Iterator for IterKeys<'a> { - type Item = String; - - fn next(&mut self) -> Option { - loop { - match self.iter_keys.next() { - Some(key) => match key { - Ok(key) => return Some(key), - Err(_) => continue, - }, - None => return None, - } - } - } -} - -pub fn iter_keys<'a>(regkey: &'a RegKey) -> IterKeys<'a> { - IterKeys::new(regkey) -} - -// endregion - -// region Iterate REG_SZ Keys - -pub struct IterSzKeys<'a> { - regkey: &'a RegKey, - iter_keys: IterKeys<'a>, -} - -impl<'a> IterSzKeys<'a> { - fn new(regkey: &'a RegKey) -> Self { - Self { regkey, iter_keys: IterKeys::new(regkey) } - } -} - -impl<'a> Iterator for IterSzKeys<'a> { - type Item = String; - - fn next(&mut self) -> Option { - loop { - if let Some(key) = self.iter_keys.next() { - // Check whether given key is REG_SZ. - match self.regkey.get_raw_value(&key) { - Ok(raw_value) => { - if matches!(raw_value.vtype, RegType::REG_SZ) { - return Some(key); - } else { - continue; - } - }, - Err(_) => continue, - } - } else { - return None; - } - } - } -} - -/// Untitled -/// -/// Given RegKey must has KEY_READ permission, otherwise the result may be inaccurate. -pub fn iter_sz_keys<'a>(regkey: &'a RegKey) -> IterSzKeys<'a> { - IterSzKeys::new(regkey) -} - -// endregion - -pub fn exclude_default_key(it: impl Iterator) -> impl Iterator { - it.filter(|x| !x.is_empty()) -}