diff --git a/crates/jmux-proto/src/lib.rs b/crates/jmux-proto/src/lib.rs index 260761937..f52b66364 100644 --- a/crates/jmux-proto/src/lib.rs +++ b/crates/jmux-proto/src/lib.rs @@ -1,4 +1,6 @@ -//! [Specification document](https://github.com/Devolutions/devolutions-gateway/blob/master/crates/jmux-proto/spec/JMUX_Spec.md) +//! [Specification document][source] +//! +//! [source]: https://github.com/Devolutions/devolutions-gateway/blob/master/docs/JMUX-spec.md use bytes::{Buf as _, BufMut as _}; use core::fmt; @@ -482,7 +484,7 @@ pub struct ChannelOpen { impl ChannelOpen { pub const NAME: &'static str = "CHANNEL OPEN"; - pub const DEFAULT_INITIAL_WINDOW_SIZE: u32 = 32_768; + pub const DEFAULT_INITIAL_WINDOW_SIZE: u32 = 64 * 1024 * 1024; // 64 MiB pub const FIXED_PART_SIZE: usize = 4 /* senderChannelId */ + 4 /* initialWindowSize */ + 2 /* maximumPacketSize */; pub fn new(id: LocalChannelId, maximum_packet_size: u16, destination_url: DestinationUrl) -> Self { diff --git a/crates/jmux-proxy/src/codec.rs b/crates/jmux-proxy/src/codec.rs index 4db047220..5d2082eaa 100644 --- a/crates/jmux-proxy/src/codec.rs +++ b/crates/jmux-proxy/src/codec.rs @@ -12,7 +12,7 @@ impl Decoder for JmuxCodec { type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - const MAX_RESERVE_CHUNK_IN_BYTES: usize = 8 * 1024; + const MAX_RESERVE_CHUNK_IN_BYTES: usize = 8 * 1024; // 8 kiB if src.len() < Header::SIZE { // Not enough data to read length marker. diff --git a/crates/jmux-proxy/src/lib.rs b/crates/jmux-proxy/src/lib.rs index 1359d13be..a32d1ee2a 100644 --- a/crates/jmux-proxy/src/lib.rs +++ b/crates/jmux-proxy/src/lib.rs @@ -34,13 +34,14 @@ use tracing::{Instrument as _, Span}; // but we need to wait until 2025 before making this change. // // iperf result for 4 * 1024: -// > 0.0000-19.4523 sec 26.6 GBytes 11.7 Gbits/sec +// > 0.0000-10.0490 sec 23.0 GBytes 19.7 Gbits/sec // // iperf result for 16 * 1024: -// > 0.0000-13.8540 sec 33.6 GBytes 20.8 Gbits/sec +// > 0.0000-10.0393 sec 30.6 GBytes 26.2 Gbits/sec // -// This is an improvement of 77.7%. -const MAXIMUM_PACKET_SIZE_IN_BYTES: u16 = 4 * 1024; +// This is an improvement of ~32.9%. +const MAXIMUM_PACKET_SIZE_IN_BYTES: u16 = 4 * 1024; // 4 kiB +const WINDOW_ADJUSTMENT_THRESHOLD: u32 = 4 * 1024; // 4 kiB pub type ApiResponseSender = oneshot::Sender; pub type ApiResponseReceiver = oneshot::Receiver; @@ -173,6 +174,7 @@ struct JmuxChannelCtx { initial_window_size: u32, window_size_updated: Arc, window_size: Arc, + remote_window_size: u32, maximum_packet_size: u16, @@ -280,9 +282,6 @@ impl JmuxSenderTask { } } - // TODO: send a signal to the main scheduler when we are done processing channel data messages - // and adjust windows for all the channels only then. - info!("Closing JMUX sender task..."); jmux_writer.flush().await?; @@ -330,6 +329,8 @@ async fn scheduler_task_impl(task: JmuxSc const MAX_CONSECUTIVE_PIPE_FAILURES: u8 = 5; let mut nb_consecutive_pipe_failures = 0; + let mut needs_window_adjustment = false; + loop { // NOTE: Current task is the "jmux scheduler" or "jmux orchestrator". // It handles the JMUX context and communicates with other tasks. @@ -557,6 +558,7 @@ async fn scheduler_task_impl(task: JmuxSc initial_window_size: msg.initial_window_size, window_size_updated: Arc::clone(&window_size_updated), window_size: Arc::clone(&window_size), + remote_window_size: msg.initial_window_size, maximum_packet_size: msg.maximum_packet_size, @@ -576,12 +578,9 @@ async fn scheduler_task_impl(task: JmuxSc let local_id = LocalChannelId::from(msg.recipient_channel_id); let peer_id = DistantChannelId::from(msg.sender_channel_id); - let (destination_url, api_response_tx) = match pending_channels.remove(&local_id) { - Some(pending) => pending, - None => { - warn!("Couldn’t find pending channel for {}", local_id); - continue; - }, + let Some((destination_url, api_response_tx)) = pending_channels.remove(&local_id) else { + warn!(channel.id = %local_id, "Couldn’t find pending channel"); + continue; }; let channel_span = info_span!(parent: parent_span.clone(), "channel", %local_id, %peer_id, url = %destination_url).entered(); @@ -603,6 +602,7 @@ async fn scheduler_task_impl(task: JmuxSc initial_window_size: msg.initial_window_size, window_size_updated: Arc::new(Notify::new()), window_size: Arc::new(AtomicUsize::new(usize::try_from(msg.initial_window_size).expect("u32-to-usize"))), + remote_window_size: msg.initial_window_size, maximum_packet_size: msg.maximum_packet_size, @@ -610,41 +610,43 @@ async fn scheduler_task_impl(task: JmuxSc })?; } Message::WindowAdjust(msg) => { - if let Some(ctx) = jmux_ctx.get_channel_mut(LocalChannelId::from(msg.recipient_channel_id)) { - ctx.window_size.fetch_add(usize::try_from(msg.window_adjustment).expect("u32-to-usize"), Ordering::SeqCst); - ctx.window_size_updated.notify_one(); - } + let id = LocalChannelId::from(msg.recipient_channel_id); + let Some(channel) = jmux_ctx.get_channel_mut(id) else { + warn!(channel.id = %id, "Couldn’t find channel"); + continue; + }; + + channel.window_size.fetch_add(usize::try_from(msg.window_adjustment).expect("u32-to-usize"), Ordering::SeqCst); + channel.window_size_updated.notify_one(); } Message::Data(msg) => { let id = LocalChannelId::from(msg.recipient_channel_id); - let data_length = u16::try_from(msg.transfer_data.len()).expect("header.size (u16) <= u16::MAX"); - let channel = match jmux_ctx.get_channel(id) { - Some(channel) => channel, - None => { - warn!(channel.id = %id, "Couldn’t find channel"); - continue; - }, + let Some(channel) = jmux_ctx.get_channel_mut(id) else { + warn!(channel.id = %id, "Couldn’t find channel"); + continue; }; - let data_tx = match data_senders.get_mut(&id) { - Some(sender) => sender, - None => { - warn!(channel.id = %id, "Received data but associated data sender is missing"); - continue; - } - }; + let payload_size = u32::try_from(msg.transfer_data.len()).expect("packet length is found by decoding a u16 in decoder"); + channel.remote_window_size = channel.remote_window_size.saturating_sub(payload_size); - if channel.maximum_packet_size < data_length { - warn!(channel.id = %id, "Packet's size is exceeding the maximum size for this channel and was dropped"); + let packet_size = Header::SIZE + msg.size(); + if usize::from(channel.maximum_packet_size) < packet_size { + channel.span.in_scope(|| { + warn!(packet_size, "Packet's size is exceeding the maximum size for this channel and was dropped"); + }); continue; } + let Some(data_tx) = data_senders.get_mut(&id) else { + channel.span.in_scope(|| { + warn!("Received data but associated data sender is missing"); + }); + continue; + }; + let _ = data_tx.send(msg.transfer_data); - // Simplest flow control logic for now: just send back a WINDOW ADJUST message to - // increase back peer’s window size. - msg_to_send_tx.send(Message::window_adjust(channel.distant_id, u32::from(data_length))) - .context("couldn’t send WINDOW ADJUST message")?; + needs_window_adjustment = true; } Message::Eof(msg) => { // Per the spec: @@ -654,12 +656,9 @@ async fn scheduler_task_impl(task: JmuxSc // > This message does not consume window space and can be sent even if no window space is available. let id = LocalChannelId::from(msg.recipient_channel_id); - let channel = match jmux_ctx.get_channel_mut(id) { - Some(channel) => channel, - None => { - warn!("Couldn’t find channel with id {}", id); - continue; - }, + let Some(channel) = jmux_ctx.get_channel_mut(id) else { + warn!(channel.id = %id, "Couldn’t find channel"); + continue; }; channel.distant_state = JmuxChannelState::Eof; @@ -684,12 +683,9 @@ async fn scheduler_task_impl(task: JmuxSc Message::OpenFailure(msg) => { let id = LocalChannelId::from(msg.recipient_channel_id); - let (destination_url, api_response_tx) = match pending_channels.remove(&id) { - Some(pending) => pending, - None => { - warn!("Couldn’t find pending channel {}", id); - continue; - }, + let Some((destination_url, api_response_tx)) = pending_channels.remove(&id) else { + warn!(channel.id = %id, "Couldn’t find pending channel"); + continue; }; warn!(local_id = %id, %destination_url, %msg.reason_code, "Channel opening failed: {}", msg.description); @@ -698,12 +694,9 @@ async fn scheduler_task_impl(task: JmuxSc } Message::Close(msg) => { let local_id = LocalChannelId::from(msg.recipient_channel_id); - let channel = match jmux_ctx.get_channel_mut(local_id) { - Some(channel) => channel, - None => { - warn!("Couldn’t find channel with id {}", local_id); - continue; - }, + let Some(channel) = jmux_ctx.get_channel_mut(local_id) else { + warn!(channel.id = %local_id, "Couldn’t find channel"); + continue; }; let distant_id = channel.distant_id; let channel_span = channel.span.clone(); @@ -729,6 +722,25 @@ async fn scheduler_task_impl(task: JmuxSc } } } + _ = core::future::ready(()), if needs_window_adjustment => { + for channel in jmux_ctx.channels.values_mut() { + let window_adjustment = channel.initial_window_size - channel.remote_window_size; + + if window_adjustment > WINDOW_ADJUSTMENT_THRESHOLD { + channel.span.in_scope(|| { + trace!(%channel.distant_id, "Send WindowAdjust message"); + }); + + msg_to_send_tx + .send(Message::window_adjust(channel.distant_id, window_adjustment)) + .context("couldn’t send WINDOW ADJUST message")?; + + channel.remote_window_size = channel.initial_window_size; + } + } + + needs_window_adjustment = false; + } } } @@ -803,24 +815,25 @@ impl DataReaderTask { loop { let window_size_now = window_size.load(Ordering::SeqCst); + if window_size_now < chunk.len() { trace!( window_size_now, - full_packet_size = bytes.len(), - "Window size insufficient to send full packet. Truncate and wait." + chunk_length = chunk.len(), + "Window size insufficient to send full chunk. Truncate and wait." ); if window_size_now > 0 { - let bytes_to_send_now = chunk.split_to(window_size_now); - window_size.fetch_sub(bytes_to_send_now.len(), Ordering::SeqCst); + let to_send_now = chunk.split_to(window_size_now); + window_size.fetch_sub(to_send_now.len(), Ordering::SeqCst); msg_to_send_tx - .send(Message::data(distant_id, bytes_to_send_now.freeze())) + .send(Message::data(distant_id, to_send_now.freeze())) .context("couldn’t send DATA message")?; } window_size_updated.notified().await; } else { - window_size.fetch_sub(bytes.len(), Ordering::SeqCst); + window_size.fetch_sub(chunk.len(), Ordering::SeqCst); msg_to_send_tx .send(Message::data(distant_id, chunk.freeze())) .context("couldn’t send DATA message")?;