Skip to content

Commit

Permalink
server: Respond WS text request as text as well
Browse files Browse the repository at this point in the history
Make the `WS` server to respond to requests in text when they arrive
in text format instead of always replying in binary.
  • Loading branch information
jsdanielh authored and hrxi committed Feb 27, 2024
1 parent 6fa5847 commit 5949c12
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 27 deletions.
2 changes: 1 addition & 1 deletion derive/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream {
async fn dispatch(
&mut self,
request: ::nimiq_jsonrpc_core::Request,
tx: Option<&::tokio::sync::mpsc::Sender<::std::vec::Vec<u8>>>,
tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>,
stream_id: u64,
) -> Option<::nimiq_jsonrpc_core::Response> {
match request.method.as_str() {
Expand Down
58 changes: 32 additions & 26 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use serde::{de::Deserialize, ser::Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::{mpsc, RwLock, RwLockReadGuard, RwLockWriteGuard};
pub use warp::filters::ws::Message;
use warp::Filter;

use nimiq_jsonrpc_core::{
Expand All @@ -45,7 +46,7 @@ pub enum Error {

/// Error from the message queues, that are used internally.
#[error("Queue error: {0}")]
Mpsc(#[from] tokio::sync::mpsc::error::SendError<Vec<u8>>),
Mpsc(#[from] tokio::sync::mpsc::error::SendError<Message>),

/// JSON error
#[error("JSON error: {0}")]
Expand Down Expand Up @@ -150,14 +151,14 @@ impl<D: Dispatcher> Server<D> {
.and_then(move |body: Bytes| {
let inner = Arc::clone(&inner);
async move {
let data = Self::handle_raw_request(inner, &body, None)
let data = Self::handle_raw_request(inner, &Message::binary(body), None)
.await
.unwrap_or_default();
.unwrap_or(Message::binary([]));

let response = http::response::Builder::new()
.status(200)
.header("Content-Type", "application/json-rpc")
.body(data)
.body(data.as_bytes().to_owned())
.unwrap(); // As long as the hard-coded status code and content-type is correct, this won't fail.

Ok::<_, warp::Rejection>(response)
Expand Down Expand Up @@ -219,7 +220,7 @@ impl<D: Dispatcher> Server<D> {
// Forwards multiplexer queue output to websocket
let forward_fut = async move {
while let Some(data) = multiplex_rx.recv().await {
tx.send(warp::ws::Message::binary(data)).await?;
tx.send(data).await?;
}
Ok::<(), Error>(())
};
Expand All @@ -235,22 +236,17 @@ impl<D: Dispatcher> Server<D> {
if let Some((code, reason)) = message.close_frame() {
// If the close message contains a code and a reason, we need to echo it back
multiplex_tx
.send(
warp::ws::Message::close_with(code, reason.to_owned())
.into_bytes(),
)
.send(warp::ws::Message::close_with(code, reason.to_owned()))
.await?;
} else {
// Otherwise we echo an empty close message
multiplex_tx
.send(warp::ws::Message::close().into_bytes())
.await?;
multiplex_tx.send(warp::ws::Message::close()).await?;
}
// Then we exit the loop which closes the connection
break;
} else if let Some(response) = Self::handle_raw_request(
Arc::clone(&inner),
message.as_bytes(),
&message,
Some(&multiplex_tx),
)
.await
Expand Down Expand Up @@ -281,10 +277,10 @@ impl<D: Dispatcher> Server<D> {
///
async fn handle_raw_request(
inner: Arc<Inner<D>>,
request: &[u8],
tx: Option<&mpsc::Sender<Vec<u8>>>,
) -> Option<Vec<u8>> {
match serde_json::from_slice(request) {
request: &Message,
tx: Option<&mpsc::Sender<Message>>,
) -> Option<Message> {
match serde_json::from_slice(request.as_bytes()) {
Ok(request) => Self::handle_request(inner, request, tx).await,
Err(_e) => {
log::error!("Received invalid JSON from client");
Expand All @@ -295,7 +291,16 @@ impl<D: Dispatcher> Server<D> {
}
}
.map(|response| {
serde_json::to_vec(&response).expect("Failed to serialize JSON RPC response")
if request.is_text() {
Message::text(
serde_json::to_string(&response)
.expect("Failed to serialize JSON RPC response"),
)
} else {
Message::binary(
serde_json::to_vec(&response).expect("Failed to serialize JSON RPC response"),
)
}
})
}

Expand All @@ -311,7 +316,7 @@ impl<D: Dispatcher> Server<D> {
async fn handle_request(
inner: Arc<Inner<D>>,
request: SingleOrBatch<Request>,
tx: Option<&mpsc::Sender<Vec<u8>>>,
tx: Option<&mpsc::Sender<Message>>,
) -> Option<SingleOrBatch<Response>> {
match request {
SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx)
Expand Down Expand Up @@ -342,7 +347,7 @@ impl<D: Dispatcher> Server<D> {
async fn handle_single_request(
inner: Arc<Inner<D>>,
request: Request,
tx: Option<&mpsc::Sender<Vec<u8>>>,
tx: Option<&mpsc::Sender<Message>>,
) -> Option<Response> {
let mut dispatcher = inner.dispatcher.write().await;
// This ID is only used for streams
Expand All @@ -366,7 +371,7 @@ pub trait Dispatcher: Send + Sync + 'static {
async fn dispatch(
&mut self,
request: Request,
tx: Option<&mpsc::Sender<Vec<u8>>>,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response>;

Expand Down Expand Up @@ -406,7 +411,7 @@ impl Dispatcher for ModularDispatcher {
async fn dispatch(
&mut self,
request: Request,
tx: Option<&mpsc::Sender<Vec<u8>>>,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response> {
for dispatcher in &mut self.dispatchers {
Expand Down Expand Up @@ -475,7 +480,7 @@ where
async fn dispatch(
&mut self,
request: Request,
tx: Option<&mpsc::Sender<Vec<u8>>>,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response> {
if self.is_allowed(&request.method) {
Expand Down Expand Up @@ -628,7 +633,7 @@ pub fn method_not_found(request: Request) -> Option<Response> {

async fn forward_notification<T>(
item: T,
tx: &mut mpsc::Sender<Vec<u8>>,
tx: &mut mpsc::Sender<Message>,
id: &SubscriptionId,
method: &str,
) -> Result<(), Error>
Expand All @@ -644,7 +649,8 @@ where

log::debug!("Sending notification: {:?}", notification);

tx.send(serde_json::to_vec(&notification)?).await?;
tx.send(Message::binary(serde_json::to_vec(&notification)?))
.await?;

Ok(())
}
Expand All @@ -664,7 +670,7 @@ where
///
pub fn connect_stream<T, S>(
stream: S,
tx: &mpsc::Sender<Vec<u8>>,
tx: &mpsc::Sender<Message>,
stream_id: u64,
method: String,
) -> SubscriptionId
Expand Down

0 comments on commit 5949c12

Please sign in to comment.