diff --git a/crates/devolutions-agent-shared/src/lib.rs b/crates/devolutions-agent-shared/src/lib.rs index a093fafcc..d9d193a6b 100644 --- a/crates/devolutions-agent-shared/src/lib.rs +++ b/crates/devolutions-agent-shared/src/lib.rs @@ -18,14 +18,20 @@ cfg_if! { const COMPANY_DIR: &str = "Devolutions"; const PROGRAM_DIR: &str = "Agent"; const APPLICATION_DIR: &str = "Devolutions\\Agent"; + + pub const GATEWAY_SERVICE_NAME: &str = "DevolutionsGateway"; } else if #[cfg(target_os = "macos")] { const COMPANY_DIR: &str = "Devolutions"; const PROGRAM_DIR: &str = "Agent"; const APPLICATION_DIR: &str = "Devolutions Agent"; + + pub const GATEWAY_SERVICE_NAME: &str = "devolutions-agent"; } else { const COMPANY_DIR: &str = "devolutions"; const PROGRAM_DIR: &str = "agent"; const APPLICATION_DIR: &str = "devolutions-agent"; + + pub const GATEWAY_SERVICE_NAME: &str = "devolutions-agent"; } } diff --git a/crates/win-api-wrappers/Cargo.toml b/crates/win-api-wrappers/Cargo.toml index 1892d7b54..0f9542a7d 100644 --- a/crates/win-api-wrappers/Cargo.toml +++ b/crates/win-api-wrappers/Cargo.toml @@ -45,6 +45,7 @@ features = [ "Win32_System_Rpc", "Win32_System_StationsAndDesktops", "Win32_System_SystemServices", + "Win32_System_Services", "Win32_System_Threading", "Win32_System_WinRT", "Win32_UI_Controls", diff --git a/crates/win-api-wrappers/src/lib.rs b/crates/win-api-wrappers/src/lib.rs index 722cb1c6d..9b147881d 100644 --- a/crates/win-api-wrappers/src/lib.rs +++ b/crates/win-api-wrappers/src/lib.rs @@ -18,6 +18,7 @@ mod lib_win { pub mod netmgmt; pub mod process; pub mod security; + pub mod service; pub mod thread; pub mod token; pub mod ui; diff --git a/crates/win-api-wrappers/src/service.rs b/crates/win-api-wrappers/src/service.rs new file mode 100644 index 000000000..6820213fa --- /dev/null +++ b/crates/win-api-wrappers/src/service.rs @@ -0,0 +1,160 @@ +use std::alloc::Layout; + +use thiserror::Error; +use windows::core::Owned; +use windows::Win32::Foundation::{ERROR_INSUFFICIENT_BUFFER, GENERIC_READ}; +use windows::Win32::System::Services::{ + OpenSCManagerW, OpenServiceW, QueryServiceConfigW, QueryServiceStatus, StartServiceW, QUERY_SERVICE_CONFIGW, + SC_HANDLE, SC_MANAGER_ALL_ACCESS, SERVICE_ALL_ACCESS, SERVICE_AUTO_START, SERVICE_BOOT_START, SERVICE_DEMAND_START, + SERVICE_DISABLED, SERVICE_QUERY_CONFIG, SERVICE_QUERY_STATUS, SERVICE_RUNNING, SERVICE_STATUS, + SERVICE_SYSTEM_START, +}; + +use crate::raw_buffer::RawBuffer; +use crate::utils::WideString; + +#[derive(Debug, Error)] +pub enum ServiceError { + #[error(transparent)] + WinAPI(#[from] windows::core::Error), +} + +pub type ServiceResult = Result; +pub struct ServiceManager { + handle: Owned, +} + +impl ServiceManager { + pub fn open_read() -> ServiceResult { + Self::open_with_access(GENERIC_READ.0) + } + + pub fn open_all_access() -> ServiceResult { + Self::open_with_access(SC_MANAGER_ALL_ACCESS) + } + + fn open_with_access(access: u32) -> ServiceResult { + // SAFETY: FFI call with no outstanding preconditions. + let handle = unsafe { OpenSCManagerW(None, None, access)? }; + + // SAFETY: On success, the handle returned by `OpenSCManagerW` is valid and owned by the + // caller. + let handle = unsafe { Owned::new(handle) }; + + Ok(Self { handle }) + } + + fn open_service_with_access(&self, service_name: &str, access: u32) -> ServiceResult { + let service_name = WideString::from(service_name); + + // SAFETY: + // - Value passed as hSCManager is valid as long as `ServiceManager` instance is alive. + // - service_name is a valid, null-terminated UTF-16 string allocated on the heap. + let handle = unsafe { OpenServiceW(*self.handle, service_name.as_pcwstr(), access)? }; + + // SAFETY: Handle returned by `OpenServiceW` is valid and needs to be closed after use, + // thus it is safe to take ownership of it via `Owned`. + let handle = unsafe { Owned::new(handle) }; + + Ok(Service { handle }) + } + + pub fn open_service_read(&self, service_name: &str) -> ServiceResult { + self.open_service_with_access(service_name, SERVICE_QUERY_CONFIG | SERVICE_QUERY_STATUS) + } + + pub fn open_service_all_access(&self, service_name: &str) -> ServiceResult { + self.open_service_with_access(service_name, SERVICE_ALL_ACCESS) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ServiceStartupMode { + Boot, + System, + Automatic, + Manual, + Disabled, +} + +pub struct Service { + handle: Owned, +} + +impl Service { + pub fn startup_mode(&self) -> ServiceResult { + let mut cbbufsize = 0u32; + let mut pcbbytesneeded = 0u32; + + // SAFETY: FFI call with no outstanding preconditions. + let result = unsafe { QueryServiceConfigW(*self.handle, None, 0, &mut pcbbytesneeded) }; + + match result { + Err(err) if err.code() == ERROR_INSUFFICIENT_BUFFER.to_hresult() => { + // Expected error, continue. + } + Err(err) => return Err(err.into()), + Ok(_) => panic!("QueryServiceConfigW should fail with ERROR_INSUFFICIENT_BUFFER"), + } + + // The most typical buffer we work with in Rust are homogeneous arrays of integers such + // as [u8] or Vec, but in Microsoft’s Win32 documentation, a `buffer` generally refers + // to a caller-allocated memory region that an API uses to either input or output data, and + // it is ultimately coerced into some other type with various alignment requirements. + // + // lpServiceConfig should point to aligned buffer that could hold a QUERY_SERVICE_CONFIGW + // structure. + let layout = Layout::from_size_align( + usize::try_from(pcbbytesneeded).expect("pcbbytesneeded < 8K as per MSDN"), + align_of::(), + ) + .expect("layout always satisfies from_size_align invariants"); + + // SAFETY: The layout initialization is checked using the Layout::from_size_align method. + let mut buffer = unsafe { RawBuffer::alloc_zeroed(layout).expect("OOM") }; + + // SAFETY: Buffer passed to `lpServiceConfig` have enough size to hold a + // QUERY_SERVICE_CONFIGW structure, as required size was queried and allocated above. + // Passed buffer have correct alignment to hold QUERY_SERVICE_CONFIGW structure. + unsafe { + // Pointer cast is valid, as `buffer` is allocated with correct alignment above. + #[expect(clippy::cast_ptr_alignment)] + QueryServiceConfigW( + *self.handle, + Some(buffer.as_mut_ptr().cast::()), + pcbbytesneeded, + &mut cbbufsize, + )? + }; + + // SAFETY: `QueryServiceConfigW` succeeded, thus `lpserviceconfig` is valid and contains + // a QUERY_SERVICE_CONFIGW structure. + let config = unsafe { buffer.as_ref_cast::() }; + + match config.dwStartType { + SERVICE_BOOT_START => Ok(ServiceStartupMode::Boot), + SERVICE_SYSTEM_START => Ok(ServiceStartupMode::System), + SERVICE_AUTO_START => Ok(ServiceStartupMode::Automatic), + SERVICE_DEMAND_START => Ok(ServiceStartupMode::Manual), + SERVICE_DISABLED => Ok(ServiceStartupMode::Disabled), + _ => panic!("WinAPI returned invalid service startup mode"), + } + } + + pub fn is_running(&self) -> ServiceResult { + let mut service_status = SERVICE_STATUS::default(); + + // SAFETY: hService is a valid handle. + // lpServiceStatus is a valid pointer to a stack-allocated SERVICE_STATUS structure. + unsafe { QueryServiceStatus(*self.handle, &mut service_status as *mut _)? }; + + Ok(service_status.dwCurrentState == SERVICE_RUNNING) + } + + pub fn start(&self) -> ServiceResult<()> { + // SAFETY: FFI call with no outstanding preconditions. + unsafe { StartServiceW(*self.handle, None)? }; + + Ok(()) + } +} diff --git a/devolutions-agent/src/session_manager.rs b/devolutions-agent/src/session_manager.rs index d1d9af116..3a33a945e 100644 --- a/devolutions-agent/src/session_manager.rs +++ b/devolutions-agent/src/session_manager.rs @@ -93,10 +93,6 @@ impl SessionManagerCtx { self.sessions.remove(&session.to_string()); } - fn get_session(&self, session: &Session) -> Option<&GatewaySession> { - self.sessions.get(&session.to_string()) - } - fn get_session_mut(&mut self, session: &Session) -> Option<&mut GatewaySession> { self.sessions.get_mut(&session.to_string()) } diff --git a/devolutions-agent/src/updater/error.rs b/devolutions-agent/src/updater/error.rs index f293c2a2d..d9ddf74f6 100644 --- a/devolutions-agent/src/updater/error.rs +++ b/devolutions-agent/src/updater/error.rs @@ -42,4 +42,8 @@ pub(crate) enum UpdaterError { Io(#[from] std::io::Error), #[error("process does not have required rights to install MSI")] NotElevated, + #[error("failed to query service state for `{product}`")] + QueryServiceState { product: Product, source: anyhow::Error }, + #[error("failed to start service for `{product}`")] + StartService { product: Product, source: anyhow::Error }, } diff --git a/devolutions-agent/src/updater/mod.rs b/devolutions-agent/src/updater/mod.rs index 7e248de9a..7ee963839 100644 --- a/devolutions-agent/src/updater/mod.rs +++ b/devolutions-agent/src/updater/mod.rs @@ -4,6 +4,7 @@ mod integrity; mod io; mod package; mod product; +mod product_actions; mod productinfo; mod security; @@ -25,6 +26,7 @@ use crate::config::ConfHandle; use integrity::validate_artifact_hash; use io::{download_binary, download_utf8, save_to_temp_file}; use package::{install_package, uninstall_package, validate_package}; +use product_actions::{build_product_actions, ProductUpdateActions}; use productinfo::DEVOLUTIONS_PRODUCTINFO_URL; use security::set_file_dacl; @@ -40,6 +42,7 @@ const PRODUCTS: &[Product] = &[Product::Gateway]; /// Context for updater task struct UpdaterCtx { product: Product, + actions: Box, conf: ConfHandle, } @@ -163,7 +166,11 @@ async fn update_product(conf: ConfHandle, product: Product, order: UpdateOrder) info!(%product, %target_version, %package_path, "Downloaded product Installer"); - let ctx = UpdaterCtx { product, conf }; + let mut ctx = UpdaterCtx { + product, + actions: build_product_actions(product), + conf, + }; if let Some(hash) = hash { validate_artifact_hash(&ctx, &package_data, &hash).context("failed to validate package file integrity")?; @@ -176,6 +183,8 @@ async fn update_product(conf: ConfHandle, product: Product, order: UpdateOrder) return Ok(()); } + ctx.actions.pre_update()?; + if let Some(downgrade) = order.downgrade { let installed_version = downgrade.installed_version; info!(%product, %installed_version, %target_version, "Downgrading product..."); @@ -192,6 +201,8 @@ async fn update_product(conf: ConfHandle, product: Product, order: UpdateOrder) .await .context("failed to install package")?; + ctx.actions.post_update()?; + info!(%product, %target_version, "Product updated!"); Ok(()) diff --git a/devolutions-agent/src/updater/package.rs b/devolutions-agent/src/updater/package.rs index 87b392a90..18d072447 100644 --- a/devolutions-agent/src/updater/package.rs +++ b/devolutions-agent/src/updater/package.rs @@ -43,14 +43,20 @@ async fn install_msi(ctx: &UpdaterCtx, path: &Utf8Path, log_path: &Utf8Path) -> info!("Installing MSI from path: {}", path); - let msi_install_result = tokio::process::Command::new("msiexec") + let mut msiexec_command = tokio::process::Command::new("msiexec"); + + msiexec_command .arg("/i") .arg(path.as_str()) .arg("/quiet") .arg("/l*v") - .arg(log_path.as_str()) - .status() - .await; + .arg(log_path.as_str()); + + for param in ctx.actions.get_msiexec_install_params() { + msiexec_command.arg(param); + } + + let msi_install_result = msiexec_command.status().await; if log_path.exists() { info!("MSI installation log: {log_path}"); diff --git a/devolutions-agent/src/updater/product_actions.rs b/devolutions-agent/src/updater/product_actions.rs new file mode 100644 index 000000000..2e2dcc0fd --- /dev/null +++ b/devolutions-agent/src/updater/product_actions.rs @@ -0,0 +1,88 @@ +use win_api_wrappers::service::{ServiceManager, ServiceStartupMode}; + +use crate::updater::{Product, UpdaterError}; + +const SERVICE_NAME: &str = "DevolutionsGateway"; + +/// Additional actions that need to be performed during product update process +pub(crate) trait ProductUpdateActions { + fn pre_update(&mut self) -> Result<(), UpdaterError>; + fn get_msiexec_install_params(&self) -> Vec; + fn post_update(&mut self) -> Result<(), UpdaterError>; +} + +/// Gateway specific update actions +#[derive(Default)] +struct GatewayUpdateActions { + service_was_running: bool, + service_startup_was_automatic: bool, +} + +impl GatewayUpdateActions { + fn pre_update_impl(&mut self) -> anyhow::Result<()> { + info!("Querying service state for Gateway"); + let service_manager = ServiceManager::open_read()?; + let service = service_manager.open_service_read(SERVICE_NAME)?; + + self.service_startup_was_automatic = service.startup_mode()? == ServiceStartupMode::Automatic; + self.service_was_running = service.is_running()?; + + info!( + "Service state for Gateway before update: running: {}, automatic_startup: {}", + self.service_was_running, self.service_startup_was_automatic + ); + + Ok(()) + } + + fn post_update_impl(&self) -> anyhow::Result<()> { + // Start service if it was running prior to the update, but service startup + // was set to manual. + if !self.service_startup_was_automatic && self.service_was_running { + info!("Starting Gateway service after update"); + + let service_manager = ServiceManager::open_all_access()?; + let service = service_manager.open_service_all_access(SERVICE_NAME)?; + service.start()?; + + info!("Gateway service started"); + } + + Ok(()) + } +} + +impl ProductUpdateActions for GatewayUpdateActions { + fn pre_update(&mut self) -> Result<(), UpdaterError> { + self.pre_update_impl() + .map_err(|source| UpdaterError::QueryServiceState { + product: Product::Gateway, + source, + }) + } + + fn get_msiexec_install_params(&self) -> Vec { + // When performing update, we want to make sure the service startup mode is restored to the + // previous state. (Installer sets Manual by default). + if self.service_startup_was_automatic { + info!("Adjusting MSIEXEC parameters for Gateway service startup mode"); + + return vec!["P.SERVICESTART=Automatic".to_string()]; + } + + Vec::new() + } + + fn post_update(&mut self) -> Result<(), UpdaterError> { + self.post_update_impl().map_err(|source| UpdaterError::StartService { + product: Product::Gateway, + source, + }) + } +} + +pub(crate) fn build_product_actions(product: Product) -> Box { + match product { + Product::Gateway => Box::new(GatewayUpdateActions::default()), + } +} diff --git a/devolutions-gateway/src/service.rs b/devolutions-gateway/src/service.rs index 0a5cf1631..57ed394d3 100644 --- a/devolutions-gateway/src/service.rs +++ b/devolutions-gateway/src/service.rs @@ -15,7 +15,8 @@ use std::time::Duration; use tap::prelude::*; use tokio::runtime::{self, Runtime}; -pub(crate) const SERVICE_NAME: &str = "devolutions-gateway"; +pub(crate) use devolutions_agent_shared::GATEWAY_SERVICE_NAME as SERVICE_NAME; + pub(crate) const DISPLAY_NAME: &str = "Devolutions Gateway"; pub(crate) const DESCRIPTION: &str = "Devolutions Gateway service";