From 7ad853b4687eaa3cc54b2cf3bf1e1a53eba076fc Mon Sep 17 00:00:00 2001 From: yngrtc Date: Thu, 25 Jan 2024 21:27:23 -0800 Subject: [PATCH] change payload in ApplicationMessage to DataChannelEvent with Open/Message/Close --- src/handlers/data/mod.rs | 78 +++++++++++++---------- src/handlers/gateway/mod.rs | 122 +++++++++++++++++++++++++----------- src/messages.rs | 11 +++- 3 files changed, 139 insertions(+), 72 deletions(-) diff --git a/src/handlers/data/mod.rs b/src/handlers/data/mod.rs index e5a5be7..66eb218 100644 --- a/src/handlers/data/mod.rs +++ b/src/handlers/data/mod.rs @@ -1,9 +1,9 @@ use crate::messages::{ - ApplicationMessage, DTLSMessageEvent, DataChannelMessage, DataChannelMessageParams, - DataChannelMessageType, MessageEvent, TaggedMessageEvent, + ApplicationMessage, DTLSMessageEvent, DataChannelEvent, DataChannelMessage, + DataChannelMessageParams, DataChannelMessageType, MessageEvent, TaggedMessageEvent, }; use data::message::{message_channel_ack::*, message_channel_open::*, message_type::*, *}; -use log::{debug, error}; +use log::{debug, error, warn}; use retty::channel::{Handler, InboundContext, InboundHandler, OutboundContext, OutboundHandler}; use shared::error::Result; use shared::marshal::*; @@ -49,7 +49,11 @@ impl InboundHandler for DataChannelInbound { let payload = Message::DataChannelAck(DataChannelAck {}).marshal()?; Ok(( - None, + Some(ApplicationMessage { + association_handle: message.association_handle, + stream_id: message.stream_id, + data_channel_event: DataChannelEvent::Open, + }), Some(DataChannelMessage { association_handle: message.association_handle, stream_id: message.stream_id, @@ -71,7 +75,7 @@ impl InboundHandler for DataChannelInbound { Some(ApplicationMessage { association_handle: message.association_handle, stream_id: message.stream_id, - payload: message.payload, + data_channel_event: DataChannelEvent::Message(message.payload), }), None, )) @@ -80,16 +84,7 @@ impl InboundHandler for DataChannelInbound { match try_read() { Ok((inbound_message, outbound_message)) => { - if let Some(application_message) = inbound_message { - debug!("recv application message {:?}", msg.transport.peer_addr); - ctx.fire_read(TaggedMessageEvent { - now: msg.now, - transport: msg.transport, - message: MessageEvent::DTLS(DTLSMessageEvent::APPLICATION( - application_message, - )), - }) - } + // first outbound message if let Some(data_channel_message) = outbound_message { debug!("send DataChannelAck message {:?}", msg.transport.peer_addr); ctx.fire_write(TaggedMessageEvent { @@ -100,6 +95,18 @@ impl InboundHandler for DataChannelInbound { )), }); } + + // then inbound message + if let Some(application_message) = inbound_message { + debug!("recv application message {:?}", msg.transport.peer_addr); + ctx.fire_read(TaggedMessageEvent { + now: msg.now, + transport: msg.transport, + message: MessageEvent::DTLS(DTLSMessageEvent::DATACHANNEL( + application_message, + )), + }) + } } Err(err) => { error!("try_read with error {}", err); @@ -119,28 +126,35 @@ impl OutboundHandler for DataChannelOutbound { type Wout = Self::Win; fn write(&mut self, ctx: &OutboundContext, msg: Self::Win) { - if let MessageEvent::DTLS(DTLSMessageEvent::APPLICATION(message)) = msg.message { + if let MessageEvent::DTLS(DTLSMessageEvent::DATACHANNEL(message)) = msg.message { debug!( "send application message {:?} with {:?}", msg.transport.peer_addr, message ); - ctx.fire_write(TaggedMessageEvent { - now: msg.now, - transport: msg.transport, - message: MessageEvent::DTLS(DTLSMessageEvent::SCTP(DataChannelMessage { - association_handle: message.association_handle, - stream_id: message.stream_id, - data_message_type: DataChannelMessageType::Text, - params: DataChannelMessageParams::Outbound { - ordered: true, - reliable: true, - max_rtx_count: 0, - max_rtx_millis: 0, - }, - payload: message.payload, - })), - }); + if let DataChannelEvent::Message(payload) = message.data_channel_event { + ctx.fire_write(TaggedMessageEvent { + now: msg.now, + transport: msg.transport, + message: MessageEvent::DTLS(DTLSMessageEvent::SCTP(DataChannelMessage { + association_handle: message.association_handle, + stream_id: message.stream_id, + data_message_type: DataChannelMessageType::Text, + params: DataChannelMessageParams::Outbound { + ordered: true, + reliable: true, + max_rtx_count: 0, + max_rtx_millis: 0, + }, + payload, + })), + }); + } else { + warn!( + "drop unsupported DATACHANNEL message {:?} to {}", + message, msg.transport.peer_addr + ); + } } else { // Bypass debug!("bypass DataChannel write {:?}", msg.transport.peer_addr); diff --git a/src/handlers/gateway/mod.rs b/src/handlers/gateway/mod.rs index e4e495c..a4f9c17 100644 --- a/src/handlers/gateway/mod.rs +++ b/src/handlers/gateway/mod.rs @@ -1,6 +1,6 @@ use crate::messages::{ - ApplicationMessage, DTLSMessageEvent, MessageEvent, RTPMessageEvent, STUNMessageEvent, - TaggedMessageEvent, + ApplicationMessage, DTLSMessageEvent, DataChannelEvent, MessageEvent, RTPMessageEvent, + STUNMessageEvent, TaggedMessageEvent, }; use crate::server::endpoint::{candidate::Candidate, Endpoint}; use crate::server::session::description::sdp_type::RTCSdpType; @@ -52,44 +52,34 @@ impl InboundHandler for GatewayInbound { type Rout = Self::Rin; fn read(&mut self, ctx: &InboundContext, msg: Self::Rin) { - let try_read = || -> Result>> { + let try_read = || -> Result> { match msg.message { MessageEvent::STUN(STUNMessageEvent::STUN(message)) => { - Ok(Some(vec![self.handle_stun_message( - msg.now, - msg.transport, - message, - )?])) + self.handle_stun_message(msg.now, msg.transport, message) } - MessageEvent::DTLS(DTLSMessageEvent::APPLICATION(message)) => { - Ok(Some(vec![self.handle_dtls_message( - msg.now, - msg.transport, - message, - )?])) + MessageEvent::DTLS(DTLSMessageEvent::DATACHANNEL(message)) => { + self.handle_dtls_message(msg.now, msg.transport, message) + } + MessageEvent::RTP(RTPMessageEvent::RTP(message)) => { + self.handle_rtp_message(msg.now, msg.transport, message) + } + MessageEvent::RTP(RTPMessageEvent::RTCP(message)) => { + self.handle_rtcp_message(msg.now, msg.transport, message) } - MessageEvent::RTP(RTPMessageEvent::RTP(message)) => Ok(Some( - self.handle_rtp_message(msg.now, msg.transport, message)?, - )), - MessageEvent::RTP(RTPMessageEvent::RTCP(message)) => Ok(Some( - self.handle_rtcp_message(msg.now, msg.transport, message)?, - )), _ => { warn!( "drop unsupported message {:?} from {}", msg.message, msg.transport.peer_addr ); - Ok(None) + Ok(vec![]) } } }; match try_read() { Ok(messages) => { - if let Some(messages) = messages { - for message in messages { - ctx.fire_write(message); - } + for message in messages { + ctx.fire_write(message); } } Err(err) => { @@ -138,7 +128,7 @@ impl GatewayInbound { now: Instant, transport_context: TransportContext, mut request: stun::message::Message, - ) -> Result { + ) -> Result> { let candidate = match self.check_stun_message(&mut request)? { Some(candidate) => candidate, None => { @@ -174,20 +164,71 @@ impl GatewayInbound { transport_context.peer_addr.port() ); - Ok(TaggedMessageEvent { + Ok(vec![TaggedMessageEvent { now, transport: transport_context, message: MessageEvent::STUN(STUNMessageEvent::STUN(response)), - }) + }]) } fn handle_dtls_message( &mut self, now: Instant, transport_context: TransportContext, - mut message: ApplicationMessage, - ) -> Result { - let request_sdp_str = String::from_utf8(message.payload.to_vec())?; + message: ApplicationMessage, + ) -> Result> { + match message.data_channel_event { + DataChannelEvent::Open => self.handle_datachannel_open( + now, + transport_context, + message.association_handle, + message.stream_id, + ), + DataChannelEvent::Message(payload) => self.handle_datachannel_message( + now, + transport_context, + message.association_handle, + message.stream_id, + payload, + ), + DataChannelEvent::Close => self.handle_datachannel_close( + now, + transport_context, + message.association_handle, + message.stream_id, + ), + } + } + + fn handle_datachannel_open( + &mut self, + _now: Instant, + _transport_context: TransportContext, + _association_handle: usize, + _stream_id: u16, + ) -> Result> { + Ok(vec![]) + } + + fn handle_datachannel_close( + &mut self, + _now: Instant, + _transport_context: TransportContext, + _association_handle: usize, + _stream_id: u16, + ) -> Result> { + Ok(vec![]) + } + + fn handle_datachannel_message( + &mut self, + now: Instant, + transport_context: TransportContext, + association_handle: usize, + stream_id: u16, + payload: BytesMut, + ) -> Result> { + let request_sdp_str = String::from_utf8(payload.to_vec())?; info!("handle_dtls_message: request_sdp {}", request_sdp_str); let request_sdp = serde_json::from_str::(&request_sdp_str) @@ -222,13 +263,18 @@ impl GatewayInbound { let response_sdp_str = serde_json::to_string(&response_sdp).map_err(|err| Error::Other(err.to_string()))?; info!("handle_dtls_message: response_sdp {}", response_sdp_str); - message.payload = BytesMut::from(response_sdp_str.as_str()); - Ok(TaggedMessageEvent { + Ok(vec![TaggedMessageEvent { now, transport: transport_context, - message: MessageEvent::DTLS(DTLSMessageEvent::APPLICATION(message)), - }) + message: MessageEvent::DTLS(DTLSMessageEvent::DATACHANNEL(ApplicationMessage { + association_handle, + stream_id, + data_channel_event: DataChannelEvent::Message(BytesMut::from( + response_sdp_str.as_str(), + )), + })), + }]) } fn handle_rtp_message( @@ -363,7 +409,7 @@ impl GatewayInbound { now: Instant, transport_context: TransportContext, transaction_id: TransactionId, - ) -> Result { + ) -> Result> { let mut response = stun::message::Message::new(); response.build(&[ Box::new(BINDING_SUCCESS), @@ -379,11 +425,11 @@ impl GatewayInbound { response.typ ); - Ok(TaggedMessageEvent { + Ok(vec![TaggedMessageEvent { now, transport: transport_context, message: MessageEvent::STUN(STUNMessageEvent::STUN(response)), - }) + }]) } fn add_endpoint( diff --git a/src/messages.rs b/src/messages.rs index 80df491..fe6f0d1 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -23,6 +23,13 @@ pub(crate) enum DataChannelMessageParams { }, } +#[derive(Debug, Clone, Eq, PartialEq)] +pub(crate) enum DataChannelEvent { + Open, + Message(BytesMut), + Close, +} + #[derive(Debug)] pub struct DataChannelMessage { pub(crate) association_handle: usize, @@ -36,7 +43,7 @@ pub struct DataChannelMessage { pub struct ApplicationMessage { pub(crate) association_handle: usize, pub(crate) stream_id: u16, - pub(crate) payload: BytesMut, + pub(crate) data_channel_event: DataChannelEvent, } #[derive(Debug)] @@ -49,7 +56,7 @@ pub enum STUNMessageEvent { pub enum DTLSMessageEvent { RAW(BytesMut), SCTP(DataChannelMessage), - APPLICATION(ApplicationMessage), + DATACHANNEL(ApplicationMessage), } #[derive(Debug)]