diff --git a/csharp/lib/src/lib.rs b/csharp/lib/src/lib.rs index c497410e31..bf76bd6e6e 100644 --- a/csharp/lib/src/lib.rs +++ b/csharp/lib/src/lib.rs @@ -7,6 +7,7 @@ use redis::{FromRedisValue, RedisResult}; use std::{ ffi::{c_void, CStr, CString}, os::raw::c_char, + sync::Arc, }; use tokio::runtime::Builder; use tokio::runtime::Runtime; @@ -79,38 +80,46 @@ pub extern "C" fn create_client( ) -> *const c_void { match create_client_internal(host, port, use_tls, success_callback, failure_callback) { Err(_) => std::ptr::null(), // TODO - log errors - Ok(client) => Box::into_raw(Box::new(client)) as *const c_void, + Ok(client) => Arc::into_raw(Arc::new(client)) as *const c_void, } } #[no_mangle] pub extern "C" fn close_client(client_ptr: *const c_void) { - let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) }; - let _runtime_handle = client_ptr.runtime.enter(); - drop(client_ptr); + let count = Arc::strong_count(&unsafe { Arc::from_raw(client_ptr as *mut Client) }); + assert!(count == 1, "Client is still in use."); } /// Expects that key and value will be kept valid until the callback is called. +/// +/// # Safety +/// +/// This function should only be called, and should complete +/// and call one of the response callbacks before `close_client` is called. After +/// that the `client_ptr` is not in a valid state. #[no_mangle] -pub extern "C" fn command( +pub unsafe extern "C" fn command( client_ptr: *const c_void, callback_index: usize, request_type: RequestType, args: *const *mut c_char, arg_count: u32, ) { - let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) }; + let client = unsafe { + // we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc. + Arc::increment_strong_count(client_ptr); + Arc::from_raw(client_ptr as *mut Client) + }; + let core_client_clone = client.clone(); // The safety of these needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. - let ptr_address = client_ptr as usize; let args_address = args as usize; let mut client_clone = client.client.clone(); client.runtime.spawn(async move { let Some(mut cmd) = request_type.get_command() else { unsafe { - let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); - (client.failure_callback)(callback_index); // TODO - report errors + (core_client_clone.failure_callback)(callback_index); // TODO - report errors return; } }; @@ -128,11 +137,12 @@ pub extern "C" fn command( .await .and_then(Option::::from_owned_redis_value); unsafe { - let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); match result { - Ok(None) => (client.success_callback)(callback_index, std::ptr::null()), - Ok(Some(c_str)) => (client.success_callback)(callback_index, c_str.as_ptr()), - Err(_) => (client.failure_callback)(callback_index), // TODO - report errors + Ok(None) => (core_client_clone.success_callback)(callback_index, std::ptr::null()), + Ok(Some(c_str)) => { + (core_client_clone.success_callback)(callback_index, c_str.as_ptr()) + } + Err(_) => (core_client_clone.failure_callback)(callback_index), // TODO - report errors }; } }); diff --git a/go/src/lib.rs b/go/src/lib.rs index 376da58dfa..b9ed0a65b9 100644 --- a/go/src/lib.rs +++ b/go/src/lib.rs @@ -16,6 +16,7 @@ use redis::cluster_routing::{ use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::{Cmd, RedisResult, Value}; use std::slice::from_raw_parts; +use std::sync::Arc; use std::{ ffi::{c_void, CString}, mem, @@ -201,7 +202,7 @@ pub unsafe extern "C" fn create_client( ), }, Ok(client) => ConnectionResponse { - conn_ptr: Box::into_raw(Box::new(client)) as *const c_void, + conn_ptr: Arc::into_raw(Arc::new(client)) as *const c_void, connection_error_message: std::ptr::null(), }, }; @@ -226,7 +227,9 @@ pub unsafe extern "C" fn create_client( #[no_mangle] pub unsafe extern "C" fn close_client(client_adapter_ptr: *const c_void) { assert!(!client_adapter_ptr.is_null()); - drop(unsafe { Box::from_raw(client_adapter_ptr as *mut ClientAdapter) }); + let client_adapter = unsafe { Arc::from_raw(client_adapter_ptr as *mut ClientAdapter) }; + let count = Arc::strong_count(&client_adapter); + assert!(count == 1, "Client is still in use."); } /// Deallocates a `ConnectionResponse`. @@ -512,7 +515,8 @@ fn valkey_value_to_command_response(value: Value) -> RedisResult value, Err(err) => { @@ -562,7 +567,7 @@ pub unsafe extern "C" fn command( let c_err_str = CString::into_raw( CString::new(message).expect("Couldn't convert error message to CString"), ); - unsafe { (client_adapter.failure_callback)(channel, c_err_str, error_type) }; + unsafe { (client_adapter_clone.failure_callback)(channel, c_err_str, error_type) }; return; } }; @@ -571,9 +576,10 @@ pub unsafe extern "C" fn command( unsafe { match result { - Ok(message) => { - (client_adapter.success_callback)(channel, Box::into_raw(Box::new(message))) - } + Ok(message) => (client_adapter_clone.success_callback)( + channel, + Box::into_raw(Box::new(message)), + ), Err(err) => { let message = errors::error_message(&err); let error_type = errors::error_type(&err); @@ -581,7 +587,7 @@ pub unsafe extern "C" fn command( let c_err_str = CString::into_raw( CString::new(message).expect("Couldn't convert error message to CString"), ); - (client_adapter.failure_callback)(channel, c_err_str, error_type); + (client_adapter_clone.failure_callback)(channel, c_err_str, error_type); } }; }