Skip to content

Commit

Permalink
Refactorings to the PAM conversation handling (#961)
Browse files Browse the repository at this point in the history
  • Loading branch information
squell authored Jan 21, 2025
2 parents f470283 + 83db1db commit 00d0d3d
Showing 1 changed file with 34 additions and 96 deletions.
130 changes: 34 additions & 96 deletions src/pam/converse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,47 +35,7 @@ impl PamMessageStyle {
}
}

/// A PamMessage contains the data in a single message of a pam conversation
/// and contains the response to that message.
pub struct PamMessage {
pub msg: String,
pub style: PamMessageStyle,
response: Option<PamBuffer>,
}

impl PamMessage {
/// Set a response value to the message.
pub fn set_response(&mut self, resp: PamBuffer) {
self.response = Some(resp);
}
}

/// Contains the conversation messages and allows setting responses to
/// each of these messages.
///
/// Note that generally there will only be one message in each conversation
/// because of historical reasons, and instead multiple conversations will
/// be started for individual messages.
pub struct Conversation {
messages: Vec<PamMessage>,
}

impl Conversation {
/// Get a mutable iterator of the messages in this conversation.
///
/// This can be used to add the resulting values to the messages.
pub fn messages_mut(&mut self) -> impl Iterator<Item = &mut PamMessage> {
self.messages.iter_mut()
}
}

pub trait Converser {
/// Handle all the message in the given conversation. They may all be
/// handled in sequence or at the same time if possible.
fn handle_conversation(&self, conversation: &mut Conversation) -> PamResult<()>;
}

pub trait SequentialConverser: Converser {
/// Handle a normal prompt, i.e. present some message and ask for a value.
/// The value is not considered a secret.
fn handle_normal_prompt(&self, msg: &str) -> PamResult<PamBuffer>;
Expand All @@ -93,31 +53,19 @@ pub trait SequentialConverser: Converser {
fn handle_info(&self, msg: &str) -> PamResult<()>;
}

impl<T> Converser for T
where
T: SequentialConverser,
{
fn handle_conversation(&self, conversation: &mut Conversation) -> PamResult<()> {
use PamMessageStyle::*;

for msg in conversation.messages_mut() {
match msg.style {
PromptEchoOn => {
msg.set_response(self.handle_normal_prompt(&msg.msg)?);
}
PromptEchoOff => {
msg.set_response(self.handle_hidden_prompt(&msg.msg)?);
}
ErrorMessage => {
self.handle_error(&msg.msg)?;
}
TextInfo => {
self.handle_info(&msg.msg)?;
}
}
}
/// Handle a single message in a conversation.
fn handle_message<C: Converser>(
converser: &C,
style: PamMessageStyle,
msg: &str,
) -> PamResult<Option<PamBuffer>> {
use PamMessageStyle::*;

Ok(())
match style {
PromptEchoOn => converser.handle_normal_prompt(msg).map(Some),
PromptEchoOff => converser.handle_hidden_prompt(msg).map(Some),
ErrorMessage => converser.handle_error(msg).map(|()| None),
TextInfo => converser.handle_info(msg).map(|()| None),
}
}

Expand All @@ -142,7 +90,7 @@ impl CLIConverser {
}
}

impl SequentialConverser for CLIConverser {
impl Converser for CLIConverser {
fn handle_normal_prompt(&self, msg: &str) -> PamResult<PamBuffer> {
if self.no_interact {
return Err(PamError::InteractionRequired);
Expand Down Expand Up @@ -203,16 +151,14 @@ pub(super) unsafe extern "C" fn converse<C: Converser>(
appdata_ptr: *mut libc::c_void,
) -> libc::c_int {
let result = std::panic::catch_unwind(|| {
// convert the input messages to Rust types
let mut conversation = Conversation {
messages: Vec::with_capacity(num_msg as usize),
};
for i in 0..num_msg as isize {
let mut resp_bufs = Vec::with_capacity(num_msg as usize);
for i in 0..num_msg as usize {
// convert the input messages to Rust types
// SAFETY: the PAM contract ensures that `num_msg` does not exceed the amount
// of messages presented to this function in `msg`, and that it is not being
// written to at the same time as we are reading it. Note that the reference
// we create does not escape this loopy body.
let message: &pam_message = unsafe { &**msg.offset(i) };
let message: &pam_message = unsafe { &**msg.add(i) };

// SAFETY: PAM ensures that the messages passed are properly null-terminated
let msg = unsafe { string_from_ptr(message.msg) };
Expand All @@ -223,26 +169,17 @@ pub(super) unsafe extern "C" fn converse<C: Converser>(
return PamErrorType::ConversationError;
};

conversation.messages.push(PamMessage {
msg,
style,
response: None,
});
}
// send the conversation off to the Rust part
// SAFETY: appdata_ptr contains the `*mut ConverserData` that is untouched by PAM
let app_data = unsafe { &mut *(appdata_ptr as *mut ConverserData<C>) };
let Ok(resp_buf) = handle_message(&app_data.converser, style, &msg) else {
return PamErrorType::ConversationError;
};

// send the conversation of to the Rust part
// SAFETY: appdata_ptr contains the `*mut ConverserData` that is untouched by PAM
let app_data = unsafe { &mut *(appdata_ptr as *mut ConverserData<C>) };
if app_data
.converser
.handle_conversation(&mut conversation)
.is_err()
{
return PamErrorType::ConversationError;
resp_bufs.push(resp_buf);
}

// Conversation should now contain response messages
// allocate enough memory for the responses, set it to zero
// Allocate enough memory for the responses, which are initialized with zero.
// SAFETY: this will either allocate the required amount of (initialized) bytes,
// or return a null pointer.
let temp_resp = unsafe {
Expand All @@ -256,13 +193,13 @@ pub(super) unsafe extern "C" fn converse<C: Converser>(
}

// Store the responses
for (i, msg) in conversation.messages.into_iter().enumerate() {
for (i, resp_buf) in resp_bufs.into_iter().enumerate() {
// SAFETY: `i` will not exceed `num_msg` by the way `conversation_messages`
// is constructed, so `temp_resp` will have allocated-and-initialized data at
// the required offset that only we have a writable pointer to.
let response: &mut pam_response = unsafe { &mut *(temp_resp.add(i)) };

if let Some(secbuf) = msg.response {
if let Some(secbuf) = resp_buf {
response.resp = secbuf.leak().as_ptr().cast();
}
}
Expand Down Expand Up @@ -295,7 +232,12 @@ mod test {
use std::pin::Pin;
use PamMessageStyle::*;

impl SequentialConverser for String {
struct PamMessage {
msg: String,
style: PamMessageStyle,
}

impl Converser for String {
fn handle_normal_prompt(&self, msg: &str) -> PamResult<PamBuffer> {
Ok(PamBuffer::new(format!("{self} says {msg}").into_bytes()))
}
Expand Down Expand Up @@ -374,11 +316,7 @@ mod test {

fn msg(style: PamMessageStyle, msg: &str) -> PamMessage {
let msg = msg.to_string();
PamMessage {
style,
msg,
response: None,
}
PamMessage { style, msg }
}

// sanity check on the test cases; lib.rs is expected to manage the lifetime of the pointer
Expand Down

0 comments on commit 00d0d3d

Please sign in to comment.