diff --git a/dtls/src/conn/mod.rs b/dtls/src/conn/mod.rs index 75b38e37b..bcad96c5d 100644 --- a/dtls/src/conn/mod.rs +++ b/dtls/src/conn/mod.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use async_trait::async_trait; use log::*; use portable_atomic::{AtomicBool, AtomicU16}; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::time::Duration; use util::replay_detector::*; use util::Conn; @@ -64,7 +64,8 @@ struct ConnReaderContext { cache: HandshakeCache, cipher_suite: Arc>>>, remote_epoch: Arc, - handshake_tx: mpsc::Sender>, + // use additional oneshot sender to mimic rendezvous channel behavior + handshake_tx: mpsc::Sender<(oneshot::Sender<()>, mpsc::Sender<()>)>, handshake_done_rx: mpsc::Receiver<()>, packet_tx: Arc>, } @@ -96,7 +97,8 @@ pub struct DTLSConn { pub(crate) flights: Option>, pub(crate) cfg: HandshakeConfig, pub(crate) retransmit: bool, - pub(crate) handshake_rx: mpsc::Receiver>, + // use additional oneshot sender to mimic rendezvous channel behavior + pub(crate) handshake_rx: mpsc::Receiver<(oneshot::Sender<()>, mpsc::Sender<()>)>, pub(crate) packet_tx: Arc>, pub(crate) handle_queue_tx: mpsc::Sender>, @@ -830,9 +832,13 @@ impl DTLSConn { if has_handshake { let (done_tx, mut done_rx) = mpsc::channel(1); - + let rendezvous_at_handshake = async { + let (rendezvous_tx, rendezvous_rx) = oneshot::channel(); + _ = ctx.handshake_tx.send((rendezvous_tx, done_tx)).await; + rendezvous_rx.await + }; tokio::select! { - _ = ctx.handshake_tx.send(done_tx) => { + _ = rendezvous_at_handshake => { let mut wait_done_rx = true; while wait_done_rx{ tokio::select!{ diff --git a/dtls/src/handshaker.rs b/dtls/src/handshaker.rs index b6e1a9e2b..640df6c61 100644 --- a/dtls/src/handshaker.rs +++ b/dtls/src/handshaker.rs @@ -330,45 +330,46 @@ impl DTLSConn { loop { tokio::select! { - done = self.handshake_rx.recv() =>{ - if done.is_none() { + done_senders = self.handshake_rx.recv() =>{ + if done_senders.is_none() { trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight.to_string()); return Err(Error::ErrAlertFatalOrClose); - } - - //trace!("[handshake:{}] {} received handshake_rx", srv_cli_str(self.state.is_client), self.current_flight.to_string()); - let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await; - drop(done); - match result { - Err((alert, mut err)) => { - trace!("[handshake:{}] {} result alert:{:?}, err:{:?}", - srv_cli_str(self.state.is_client), - self.current_flight.to_string(), - alert, - err); - - if let Some(alert) = alert { - let alert_err = self.notify(alert.alert_level, alert.alert_description).await; - - if let Err(alert_err) = alert_err { - if err.is_some() { - err = Some(alert_err); + } else if let Some((rendezvous_tx, done_tx)) = done_senders { + rendezvous_tx.send(()).ok(); + //trace!("[handshake:{}] {} received handshake_rx", srv_cli_str(self.state.is_client), self.current_flight.to_string()); + let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await; + drop(done_tx); + match result { + Err((alert, mut err)) => { + trace!("[handshake:{}] {} result alert:{:?}, err:{:?}", + srv_cli_str(self.state.is_client), + self.current_flight.to_string(), + alert, + err); + + if let Some(alert) = alert { + let alert_err = self.notify(alert.alert_level, alert.alert_description).await; + + if let Err(alert_err) = alert_err { + if err.is_some() { + err = Some(alert_err); + } } } + if let Some(err) = err { + return Err(err); + } } - if let Some(err) = err { - return Err(err); - } - } - Ok(next_flight) => { - trace!("[handshake:{}] {} -> {}", srv_cli_str(self.state.is_client), self.current_flight.to_string(), next_flight.to_string()); - if next_flight.is_last_recv_flight() && self.current_flight.to_string() == next_flight.to_string() { - return Ok(HandshakeState::Finished); + Ok(next_flight) => { + trace!("[handshake:{}] {} -> {}", srv_cli_str(self.state.is_client), self.current_flight.to_string(), next_flight.to_string()); + if next_flight.is_last_recv_flight() && self.current_flight.to_string() == next_flight.to_string() { + return Ok(HandshakeState::Finished); + } + self.current_flight = next_flight; + return Ok(HandshakeState::Preparing); } - self.current_flight = next_flight; - return Ok(HandshakeState::Preparing); - } - }; + }; + } } _ = retransmit_timer.as_mut() =>{