Skip to content

Commit

Permalink
Safer wrapping for FFI clients.
Browse files Browse the repository at this point in the history
Use `Arc` instead of `Box` to wrap the `ClientAdapter`, so that the lifetime of the adapter will be extended even in cases where the client is closed during commands.
This also saves on using `as usize` conversions to the pointer, which cause provenance to be lost and might lead to undefined behavior.

Signed-off-by: Shachar Langbeheim <[email protected]>
  • Loading branch information
nihohit committed Dec 31, 2024
1 parent abec885 commit 47b8318
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 27 deletions.
36 changes: 23 additions & 13 deletions csharp/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use redis::{FromRedisValue, RedisResult};
use std::{
ffi::{c_void, CStr, CString},
os::raw::c_char,
sync::Arc,
};
use tokio::runtime::Builder;
use tokio::runtime::Runtime;
Expand Down Expand Up @@ -79,38 +80,46 @@ pub extern "C" fn create_client(
) -> *const c_void {
match create_client_internal(host, port, use_tls, success_callback, failure_callback) {
Err(_) => std::ptr::null(), // TODO - log errors
Ok(client) => Box::into_raw(Box::new(client)) as *const c_void,
Ok(client) => Arc::into_raw(Arc::new(client)) as *const c_void,
}
}

#[no_mangle]
pub extern "C" fn close_client(client_ptr: *const c_void) {
let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) };
let _runtime_handle = client_ptr.runtime.enter();
drop(client_ptr);
let count = Arc::strong_count(&unsafe { Arc::from_raw(client_ptr as *mut Client) });
assert!(count == 1, "Client is still in use.");
}

/// Expects that key and value will be kept valid until the callback is called.
///
/// # Safety
///
/// This function should only be called, and should complete
/// and call one of the response callbacks before `close_client` is called. After
/// that the `client_ptr` is not in a valid state.
#[no_mangle]
pub extern "C" fn command(
pub unsafe extern "C" fn command(
client_ptr: *const c_void,
callback_index: usize,
request_type: RequestType,
args: *const *mut c_char,
arg_count: u32,
) {
let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) };
let client = unsafe {
// we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc.
Arc::increment_strong_count(client_ptr);
Arc::from_raw(client_ptr as *mut Client)
};
let core_client_clone = client.clone();

// The safety of these needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed.
let ptr_address = client_ptr as usize;
let args_address = args as usize;

let mut client_clone = client.client.clone();
client.runtime.spawn(async move {
let Some(mut cmd) = request_type.get_command() else {
unsafe {
let client = Box::leak(Box::from_raw(ptr_address as *mut Client));
(client.failure_callback)(callback_index); // TODO - report errors
(core_client_clone.failure_callback)(callback_index); // TODO - report errors
return;
}
};
Expand All @@ -128,11 +137,12 @@ pub extern "C" fn command(
.await
.and_then(Option::<CString>::from_owned_redis_value);
unsafe {
let client = Box::leak(Box::from_raw(ptr_address as *mut Client));
match result {
Ok(None) => (client.success_callback)(callback_index, std::ptr::null()),
Ok(Some(c_str)) => (client.success_callback)(callback_index, c_str.as_ptr()),
Err(_) => (client.failure_callback)(callback_index), // TODO - report errors
Ok(None) => (core_client_clone.success_callback)(callback_index, std::ptr::null()),
Ok(Some(c_str)) => {
(core_client_clone.success_callback)(callback_index, c_str.as_ptr())
}
Err(_) => (core_client_clone.failure_callback)(callback_index), // TODO - report errors
};
}
});
Expand Down
34 changes: 20 additions & 14 deletions go/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use redis::cluster_routing::{
use redis::cluster_routing::{ResponsePolicy, Routable};
use redis::{Cmd, RedisResult, Value};
use std::slice::from_raw_parts;
use std::sync::Arc;
use std::{
ffi::{c_void, CString},
mem,
Expand Down Expand Up @@ -201,7 +202,7 @@ pub unsafe extern "C" fn create_client(
),
},
Ok(client) => ConnectionResponse {
conn_ptr: Box::into_raw(Box::new(client)) as *const c_void,
conn_ptr: Arc::into_raw(Arc::new(client)) as *const c_void,
connection_error_message: std::ptr::null(),
},
};
Expand All @@ -226,7 +227,9 @@ pub unsafe extern "C" fn create_client(
#[no_mangle]
pub unsafe extern "C" fn close_client(client_adapter_ptr: *const c_void) {
assert!(!client_adapter_ptr.is_null());
drop(unsafe { Box::from_raw(client_adapter_ptr as *mut ClientAdapter) });
let client_adapter = unsafe { Arc::from_raw(client_adapter_ptr as *mut ClientAdapter) };
let count = Arc::strong_count(&client_adapter);
assert!(count == 1, "Client is still in use.");
}

/// Deallocates a `ConnectionResponse`.
Expand Down Expand Up @@ -512,7 +515,8 @@ fn valkey_value_to_command_response(value: Value) -> RedisResult<CommandResponse
///
/// # Safety
///
/// * TODO: finish safety section.
/// This function should only be called, and should complete and call one of the response callbacks before `close_client` is called.
/// After `close_client` is called, the `client_ptr` is not in a valid state.
#[no_mangle]
pub unsafe extern "C" fn command(
client_adapter_ptr: *const c_void,
Expand All @@ -524,11 +528,13 @@ pub unsafe extern "C" fn command(
route_bytes: *const u8,
route_bytes_len: usize,
) {
let client_adapter =
unsafe { Box::leak(Box::from_raw(client_adapter_ptr as *mut ClientAdapter)) };
// The safety of this needs to be ensured by the calling code. Cannot dispose of the pointer before
// all operations have completed.
let ptr_address = client_adapter_ptr as usize;
let client_adapter = unsafe {
// we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc.
Arc::increment_strong_count(client_adapter_ptr);
Arc::from_raw(client_adapter_ptr as *mut ClientAdapter)
};

let client_adapter_clone = client_adapter.clone();

let arg_vec =
unsafe { convert_double_pointer_to_vec(args as *const *const c_void, arg_count, args_len) };
Expand All @@ -552,7 +558,6 @@ pub unsafe extern "C" fn command(
let result = client_clone
.send_command(&cmd, get_route(route, Some(&cmd)))
.await;
let client_adapter = unsafe { Box::leak(Box::from_raw(ptr_address as *mut ClientAdapter)) };
let value = match result {
Ok(value) => value,
Err(err) => {
Expand All @@ -562,7 +567,7 @@ pub unsafe extern "C" fn command(
let c_err_str = CString::into_raw(
CString::new(message).expect("Couldn't convert error message to CString"),
);
unsafe { (client_adapter.failure_callback)(channel, c_err_str, error_type) };
unsafe { (client_adapter_clone.failure_callback)(channel, c_err_str, error_type) };
return;
}
};
Expand All @@ -571,17 +576,18 @@ pub unsafe extern "C" fn command(

unsafe {
match result {
Ok(message) => {
(client_adapter.success_callback)(channel, Box::into_raw(Box::new(message)))
}
Ok(message) => (client_adapter_clone.success_callback)(
channel,
Box::into_raw(Box::new(message)),
),
Err(err) => {
let message = errors::error_message(&err);
let error_type = errors::error_type(&err);

let c_err_str = CString::into_raw(
CString::new(message).expect("Couldn't convert error message to CString"),
);
(client_adapter.failure_callback)(channel, c_err_str, error_type);
(client_adapter_clone.failure_callback)(channel, c_err_str, error_type);
}
};
}
Expand Down

0 comments on commit 47b8318

Please sign in to comment.