Skip to content

Commit

Permalink
fix(agent): add logic to start Gateway service after update (#1172)
Browse files Browse the repository at this point in the history
This PR adds the following logic to the updater:
* Gateway service startup mode when correctly performing MSI install
* Start service if it was in "Manual" startup mode before installation, but in the "Running" state, to avoid unintuitive behavior
  • Loading branch information
pacmancoder authored Jan 9, 2025
1 parent 3b8667d commit 651d8cf
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 10 deletions.
6 changes: 6 additions & 0 deletions crates/devolutions-agent-shared/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@ cfg_if! {
const COMPANY_DIR: &str = "Devolutions";
const PROGRAM_DIR: &str = "Agent";
const APPLICATION_DIR: &str = "Devolutions\\Agent";

pub const GATEWAY_SERVICE_NAME: &str = "DevolutionsGateway";
} else if #[cfg(target_os = "macos")] {
const COMPANY_DIR: &str = "Devolutions";
const PROGRAM_DIR: &str = "Agent";
const APPLICATION_DIR: &str = "Devolutions Agent";

pub const GATEWAY_SERVICE_NAME: &str = "devolutions-agent";
} else {
const COMPANY_DIR: &str = "devolutions";
const PROGRAM_DIR: &str = "agent";
const APPLICATION_DIR: &str = "devolutions-agent";

pub const GATEWAY_SERVICE_NAME: &str = "devolutions-agent";
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/win-api-wrappers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ features = [
"Win32_System_Rpc",
"Win32_System_StationsAndDesktops",
"Win32_System_SystemServices",
"Win32_System_Services",
"Win32_System_Threading",
"Win32_System_WinRT",
"Win32_UI_Controls",
Expand Down
1 change: 1 addition & 0 deletions crates/win-api-wrappers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod lib_win {
pub mod netmgmt;
pub mod process;
pub mod security;
pub mod service;
pub mod thread;
pub mod token;
pub mod ui;
Expand Down
160 changes: 160 additions & 0 deletions crates/win-api-wrappers/src/service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use std::alloc::Layout;

use thiserror::Error;
use windows::core::Owned;
use windows::Win32::Foundation::{ERROR_INSUFFICIENT_BUFFER, GENERIC_READ};
use windows::Win32::System::Services::{
OpenSCManagerW, OpenServiceW, QueryServiceConfigW, QueryServiceStatus, StartServiceW, QUERY_SERVICE_CONFIGW,
SC_HANDLE, SC_MANAGER_ALL_ACCESS, SERVICE_ALL_ACCESS, SERVICE_AUTO_START, SERVICE_BOOT_START, SERVICE_DEMAND_START,
SERVICE_DISABLED, SERVICE_QUERY_CONFIG, SERVICE_QUERY_STATUS, SERVICE_RUNNING, SERVICE_STATUS,
SERVICE_SYSTEM_START,
};

use crate::raw_buffer::RawBuffer;
use crate::utils::WideString;

#[derive(Debug, Error)]
pub enum ServiceError {
#[error(transparent)]
WinAPI(#[from] windows::core::Error),
}

pub type ServiceResult<T> = Result<T, ServiceError>;
pub struct ServiceManager {
handle: Owned<SC_HANDLE>,
}

impl ServiceManager {
pub fn open_read() -> ServiceResult<Self> {
Self::open_with_access(GENERIC_READ.0)
}

pub fn open_all_access() -> ServiceResult<Self> {
Self::open_with_access(SC_MANAGER_ALL_ACCESS)
}

fn open_with_access(access: u32) -> ServiceResult<Self> {
// SAFETY: FFI call with no outstanding preconditions.
let handle = unsafe { OpenSCManagerW(None, None, access)? };

// SAFETY: On success, the handle returned by `OpenSCManagerW` is valid and owned by the
// caller.
let handle = unsafe { Owned::new(handle) };

Ok(Self { handle })
}

fn open_service_with_access(&self, service_name: &str, access: u32) -> ServiceResult<Service> {
let service_name = WideString::from(service_name);

// SAFETY:
// - Value passed as hSCManager is valid as long as `ServiceManager` instance is alive.
// - service_name is a valid, null-terminated UTF-16 string allocated on the heap.
let handle = unsafe { OpenServiceW(*self.handle, service_name.as_pcwstr(), access)? };

// SAFETY: Handle returned by `OpenServiceW` is valid and needs to be closed after use,
// thus it is safe to take ownership of it via `Owned`.
let handle = unsafe { Owned::new(handle) };

Ok(Service { handle })
}

pub fn open_service_read(&self, service_name: &str) -> ServiceResult<Service> {
self.open_service_with_access(service_name, SERVICE_QUERY_CONFIG | SERVICE_QUERY_STATUS)
}

pub fn open_service_all_access(&self, service_name: &str) -> ServiceResult<Service> {
self.open_service_with_access(service_name, SERVICE_ALL_ACCESS)
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServiceStartupMode {
Boot,
System,
Automatic,
Manual,
Disabled,
}

pub struct Service {
handle: Owned<SC_HANDLE>,
}

impl Service {
pub fn startup_mode(&self) -> ServiceResult<ServiceStartupMode> {
let mut cbbufsize = 0u32;
let mut pcbbytesneeded = 0u32;

// SAFETY: FFI call with no outstanding preconditions.
let result = unsafe { QueryServiceConfigW(*self.handle, None, 0, &mut pcbbytesneeded) };

match result {
Err(err) if err.code() == ERROR_INSUFFICIENT_BUFFER.to_hresult() => {
// Expected error, continue.
}
Err(err) => return Err(err.into()),
Ok(_) => panic!("QueryServiceConfigW should fail with ERROR_INSUFFICIENT_BUFFER"),
}

// The most typical buffer we work with in Rust are homogeneous arrays of integers such
// as [u8] or Vec<u8>, but in Microsoft’s Win32 documentation, a `buffer` generally refers
// to a caller-allocated memory region that an API uses to either input or output data, and
// it is ultimately coerced into some other type with various alignment requirements.
//
// lpServiceConfig should point to aligned buffer that could hold a QUERY_SERVICE_CONFIGW
// structure.
let layout = Layout::from_size_align(
usize::try_from(pcbbytesneeded).expect("pcbbytesneeded < 8K as per MSDN"),
align_of::<QUERY_SERVICE_CONFIGW>(),
)
.expect("layout always satisfies from_size_align invariants");

// SAFETY: The layout initialization is checked using the Layout::from_size_align method.
let mut buffer = unsafe { RawBuffer::alloc_zeroed(layout).expect("OOM") };

// SAFETY: Buffer passed to `lpServiceConfig` have enough size to hold a
// QUERY_SERVICE_CONFIGW structure, as required size was queried and allocated above.
// Passed buffer have correct alignment to hold QUERY_SERVICE_CONFIGW structure.
unsafe {
// Pointer cast is valid, as `buffer` is allocated with correct alignment above.
#[expect(clippy::cast_ptr_alignment)]
QueryServiceConfigW(
*self.handle,
Some(buffer.as_mut_ptr().cast::<QUERY_SERVICE_CONFIGW>()),
pcbbytesneeded,
&mut cbbufsize,
)?
};

// SAFETY: `QueryServiceConfigW` succeeded, thus `lpserviceconfig` is valid and contains
// a QUERY_SERVICE_CONFIGW structure.
let config = unsafe { buffer.as_ref_cast::<QUERY_SERVICE_CONFIGW>() };

match config.dwStartType {
SERVICE_BOOT_START => Ok(ServiceStartupMode::Boot),
SERVICE_SYSTEM_START => Ok(ServiceStartupMode::System),
SERVICE_AUTO_START => Ok(ServiceStartupMode::Automatic),
SERVICE_DEMAND_START => Ok(ServiceStartupMode::Manual),
SERVICE_DISABLED => Ok(ServiceStartupMode::Disabled),
_ => panic!("WinAPI returned invalid service startup mode"),
}
}

pub fn is_running(&self) -> ServiceResult<bool> {
let mut service_status = SERVICE_STATUS::default();

// SAFETY: hService is a valid handle.
// lpServiceStatus is a valid pointer to a stack-allocated SERVICE_STATUS structure.
unsafe { QueryServiceStatus(*self.handle, &mut service_status as *mut _)? };

Ok(service_status.dwCurrentState == SERVICE_RUNNING)
}

pub fn start(&self) -> ServiceResult<()> {
// SAFETY: FFI call with no outstanding preconditions.
unsafe { StartServiceW(*self.handle, None)? };

Ok(())
}
}
4 changes: 0 additions & 4 deletions devolutions-agent/src/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,6 @@ impl SessionManagerCtx {
self.sessions.remove(&session.to_string());
}

fn get_session(&self, session: &Session) -> Option<&GatewaySession> {
self.sessions.get(&session.to_string())
}

fn get_session_mut(&mut self, session: &Session) -> Option<&mut GatewaySession> {
self.sessions.get_mut(&session.to_string())
}
Expand Down
4 changes: 4 additions & 0 deletions devolutions-agent/src/updater/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,8 @@ pub(crate) enum UpdaterError {
Io(#[from] std::io::Error),
#[error("process does not have required rights to install MSI")]
NotElevated,
#[error("failed to query service state for `{product}`")]
QueryServiceState { product: Product, source: anyhow::Error },
#[error("failed to start service for `{product}`")]
StartService { product: Product, source: anyhow::Error },
}
13 changes: 12 additions & 1 deletion devolutions-agent/src/updater/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod integrity;
mod io;
mod package;
mod product;
mod product_actions;
mod productinfo;
mod security;

Expand All @@ -25,6 +26,7 @@ use crate::config::ConfHandle;
use integrity::validate_artifact_hash;
use io::{download_binary, download_utf8, save_to_temp_file};
use package::{install_package, uninstall_package, validate_package};
use product_actions::{build_product_actions, ProductUpdateActions};
use productinfo::DEVOLUTIONS_PRODUCTINFO_URL;
use security::set_file_dacl;

Expand All @@ -40,6 +42,7 @@ const PRODUCTS: &[Product] = &[Product::Gateway];
/// Context for updater task
struct UpdaterCtx {
product: Product,
actions: Box<dyn ProductUpdateActions + Send + Sync + 'static>,
conf: ConfHandle,
}

Expand Down Expand Up @@ -163,7 +166,11 @@ async fn update_product(conf: ConfHandle, product: Product, order: UpdateOrder)

info!(%product, %target_version, %package_path, "Downloaded product Installer");

let ctx = UpdaterCtx { product, conf };
let mut ctx = UpdaterCtx {
product,
actions: build_product_actions(product),
conf,
};

if let Some(hash) = hash {
validate_artifact_hash(&ctx, &package_data, &hash).context("failed to validate package file integrity")?;
Expand All @@ -176,6 +183,8 @@ async fn update_product(conf: ConfHandle, product: Product, order: UpdateOrder)
return Ok(());
}

ctx.actions.pre_update()?;

if let Some(downgrade) = order.downgrade {
let installed_version = downgrade.installed_version;
info!(%product, %installed_version, %target_version, "Downgrading product...");
Expand All @@ -192,6 +201,8 @@ async fn update_product(conf: ConfHandle, product: Product, order: UpdateOrder)
.await
.context("failed to install package")?;

ctx.actions.post_update()?;

info!(%product, %target_version, "Product updated!");

Ok(())
Expand Down
14 changes: 10 additions & 4 deletions devolutions-agent/src/updater/package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,20 @@ async fn install_msi(ctx: &UpdaterCtx, path: &Utf8Path, log_path: &Utf8Path) ->

info!("Installing MSI from path: {}", path);

let msi_install_result = tokio::process::Command::new("msiexec")
let mut msiexec_command = tokio::process::Command::new("msiexec");

msiexec_command
.arg("/i")
.arg(path.as_str())
.arg("/quiet")
.arg("/l*v")
.arg(log_path.as_str())
.status()
.await;
.arg(log_path.as_str());

for param in ctx.actions.get_msiexec_install_params() {
msiexec_command.arg(param);
}

let msi_install_result = msiexec_command.status().await;

if log_path.exists() {
info!("MSI installation log: {log_path}");
Expand Down
88 changes: 88 additions & 0 deletions devolutions-agent/src/updater/product_actions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use win_api_wrappers::service::{ServiceManager, ServiceStartupMode};

use crate::updater::{Product, UpdaterError};

const SERVICE_NAME: &str = "DevolutionsGateway";

/// Additional actions that need to be performed during product update process
pub(crate) trait ProductUpdateActions {
fn pre_update(&mut self) -> Result<(), UpdaterError>;
fn get_msiexec_install_params(&self) -> Vec<String>;
fn post_update(&mut self) -> Result<(), UpdaterError>;
}

/// Gateway specific update actions
#[derive(Default)]
struct GatewayUpdateActions {
service_was_running: bool,
service_startup_was_automatic: bool,
}

impl GatewayUpdateActions {
fn pre_update_impl(&mut self) -> anyhow::Result<()> {
info!("Querying service state for Gateway");
let service_manager = ServiceManager::open_read()?;
let service = service_manager.open_service_read(SERVICE_NAME)?;

self.service_startup_was_automatic = service.startup_mode()? == ServiceStartupMode::Automatic;
self.service_was_running = service.is_running()?;

info!(
"Service state for Gateway before update: running: {}, automatic_startup: {}",
self.service_was_running, self.service_startup_was_automatic
);

Ok(())
}

fn post_update_impl(&self) -> anyhow::Result<()> {
// Start service if it was running prior to the update, but service startup
// was set to manual.
if !self.service_startup_was_automatic && self.service_was_running {
info!("Starting Gateway service after update");

let service_manager = ServiceManager::open_all_access()?;
let service = service_manager.open_service_all_access(SERVICE_NAME)?;
service.start()?;

info!("Gateway service started");
}

Ok(())
}
}

impl ProductUpdateActions for GatewayUpdateActions {
fn pre_update(&mut self) -> Result<(), UpdaterError> {
self.pre_update_impl()
.map_err(|source| UpdaterError::QueryServiceState {
product: Product::Gateway,
source,
})
}

fn get_msiexec_install_params(&self) -> Vec<String> {
// When performing update, we want to make sure the service startup mode is restored to the
// previous state. (Installer sets Manual by default).
if self.service_startup_was_automatic {
info!("Adjusting MSIEXEC parameters for Gateway service startup mode");

return vec!["P.SERVICESTART=Automatic".to_string()];
}

Vec::new()
}

fn post_update(&mut self) -> Result<(), UpdaterError> {
self.post_update_impl().map_err(|source| UpdaterError::StartService {
product: Product::Gateway,
source,
})
}
}

pub(crate) fn build_product_actions(product: Product) -> Box<dyn ProductUpdateActions + Sync + Send + 'static> {
match product {
Product::Gateway => Box::new(GatewayUpdateActions::default()),
}
}
3 changes: 2 additions & 1 deletion devolutions-gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use std::time::Duration;
use tap::prelude::*;
use tokio::runtime::{self, Runtime};

pub(crate) const SERVICE_NAME: &str = "devolutions-gateway";
pub(crate) use devolutions_agent_shared::GATEWAY_SERVICE_NAME as SERVICE_NAME;

pub(crate) const DISPLAY_NAME: &str = "Devolutions Gateway";
pub(crate) const DESCRIPTION: &str = "Devolutions Gateway service";

Expand Down

0 comments on commit 651d8cf

Please sign in to comment.