From 2eb24ab00aa70c70dc34ce9c457311528e4a13e9 Mon Sep 17 00:00:00 2001 From: yngrtc Date: Fri, 1 Mar 2024 15:19:01 -0800 Subject: [PATCH] refactor SctpHandler by moving SctpEndpoint to Endpoint --- examples/sfu_impl/mod.rs | 7 +- rtc | 2 +- src/endpoint/transport.rs | 23 ++++ src/handler/sctp.rs | 279 +++++++++++++++++++++----------------- src/server/states.rs | 14 -- 5 files changed, 180 insertions(+), 145 deletions(-) diff --git a/examples/sfu_impl/mod.rs b/examples/sfu_impl/mod.rs index abb5bd0..1e9afdc 100644 --- a/examples/sfu_impl/mod.rs +++ b/examples/sfu_impl/mod.rs @@ -106,8 +106,6 @@ pub fn run_sfu( rx: Receiver, server_config: Arc, ) -> anyhow::Result<()> { - let sctp_endpoint_config = Arc::new(sctp::EndpointConfig::default()); - let server_states = Rc::new(RefCell::new(ServerStates::new( server_config, socket.local_addr()?, @@ -121,7 +119,6 @@ pub fn run_sfu( socket.local_addr()?, outgoing_queue.clone(), server_states.clone(), - sctp_endpoint_config, ); let mut buf = vec![0; 2000]; @@ -212,7 +209,6 @@ fn build_pipeline( local_addr: SocketAddr, writer: Rc>>, server_states: Rc>, - sctp_endpoint_config: Arc, ) -> Rc> { let pipeline: Pipeline = Pipeline::new(); @@ -222,8 +218,7 @@ fn build_pipeline( let stun_handler = StunHandler::new(); // DTLS let dtls_handler = DtlsHandler::new(local_addr, Rc::clone(&server_states)); - let sctp_handler = - SctpHandler::new(local_addr, Rc::clone(&server_states), sctp_endpoint_config); + let sctp_handler = SctpHandler::new(local_addr, Rc::clone(&server_states)); let data_channel_handler = DataChannelHandler::new(); // SRTP let srtp_handler = SrtpHandler::new(Rc::clone(&server_states)); diff --git a/rtc b/rtc index 4b62d38..dd7992e 160000 --- a/rtc +++ b/rtc @@ -1 +1 @@ -Subproject commit 4b62d38d4c6600fe17f1ba1853721852f6309596 +Subproject commit dd7992ee4f83e207e6478c9b853f3544c38641b0 diff --git a/src/endpoint/transport.rs b/src/endpoint/transport.rs index 04e885d..adc0213 100644 --- a/src/endpoint/transport.rs +++ b/src/endpoint/transport.rs @@ -1,6 +1,8 @@ use crate::endpoint::candidate::Candidate; use crate::types::FourTuple; +use sctp::{Association, AssociationHandle}; use srtp::context::Context; +use std::collections::HashMap; use std::rc::Rc; use std::sync::Arc; @@ -15,6 +17,7 @@ pub(crate) struct Transport { // SCTP sctp_endpoint: sctp::Endpoint, + sctp_associations: HashMap, // DataChannel association_handle: Option, @@ -41,6 +44,7 @@ impl Transport { dtls_endpoint: dtls::endpoint::Endpoint::new(Some(dtls_handshake_config)), sctp_endpoint: sctp::Endpoint::new(sctp_endpoint_config, Some(sctp_server_config)), + sctp_associations: HashMap::new(), association_handle: None, stream_id: None, @@ -74,6 +78,25 @@ impl Transport { &self.sctp_endpoint } + pub(crate) fn get_mut_sctp_associations( + &mut self, + ) -> &mut HashMap { + &mut self.sctp_associations + } + + pub(crate) fn get_mut_sctp_endpoint_associations( + &mut self, + ) -> ( + &mut sctp::Endpoint, + &mut HashMap, + ) { + (&mut self.sctp_endpoint, &mut self.sctp_associations) + } + + pub(crate) fn get_sctp_associations(&self) -> &HashMap { + &self.sctp_associations + } + pub(crate) fn local_srtp_context(&mut self) -> Option<&mut Context> { self.local_srtp_context.as_mut() } diff --git a/src/handler/sctp.rs b/src/handler/sctp.rs index a9958bd..aa1569b 100644 --- a/src/handler/sctp.rs +++ b/src/handler/sctp.rs @@ -17,60 +17,49 @@ use std::collections::HashMap; use std::collections::VecDeque; use std::net::SocketAddr; use std::rc::Rc; -use std::sync::Arc; use std::time::Instant; struct SctpInbound { local_addr: SocketAddr, server_states: Rc>, - sctp_endpoint: Rc>, - transmits: VecDeque, internal_buffer: Vec, } struct SctpOutbound { local_addr: SocketAddr, server_states: Rc>, - sctp_endpoint: Rc>, - - transmits: VecDeque, } pub struct SctpHandler { sctp_inbound: SctpInbound, sctp_outbound: SctpOutbound, } +enum SctpMessage { + Inbound(DataChannelMessage), + Outbound(Transmit), +} + impl SctpHandler { - pub fn new( - local_addr: SocketAddr, - server_states: Rc>, - sctp_endpoint_config: Arc, - ) -> Self { - let sctp_server_config = - Arc::clone(&server_states.borrow().server_config().sctp_server_config); - let sctp_endpoint = Rc::new(RefCell::new(sctp::Endpoint::new( - sctp_endpoint_config, - Some(Arc::clone(&sctp_server_config)), - ))); + pub fn new(local_addr: SocketAddr, server_states: Rc>) -> Self { + let max_message_size = { + let server_states = server_states.borrow(); + server_states + .server_config() + .sctp_server_config + .transport + .max_message_size() as usize + }; SctpHandler { sctp_inbound: SctpInbound { local_addr, server_states: Rc::clone(&server_states), - sctp_endpoint: Rc::clone(&sctp_endpoint), - transmits: VecDeque::new(), - internal_buffer: vec![ - 0u8; - sctp_server_config.transport.max_message_size() as usize - ], + internal_buffer: vec![0u8; max_message_size], }, sctp_outbound: SctpOutbound { local_addr, server_states, - sctp_endpoint, - - transmits: VecDeque::new(), }, } } @@ -83,26 +72,25 @@ impl InboundHandler for SctpInbound { fn read(&mut self, ctx: &InboundContext, msg: Self::Rin) { if let MessageEvent::Dtls(DTLSMessageEvent::Raw(dtls_message)) = msg.message { debug!("recv sctp RAW {:?}", msg.transport.peer_addr); - let try_read = || -> Result> { - let handle_result = { - let mut sctp_endpoint = self.sctp_endpoint.borrow_mut(); - sctp_endpoint.handle( - msg.now, - msg.transport.peer_addr, - Some(msg.transport.local_addr.ip()), - msg.transport.ecn, - dtls_message.freeze(), //TODO: switch API Bytes to BytesMut - ) - }; + let four_tuple = (&msg.transport).into(); + let try_read = || -> Result> { let mut server_states = self.server_states.borrow_mut(); + let transport = server_states.get_mut_transport(&four_tuple)?; + let (sctp_endpoint, sctp_associations) = + transport.get_mut_sctp_endpoint_associations(); let mut sctp_events: HashMap> = HashMap::new(); - if let Some((ch, event)) = handle_result { + if let Some((ch, event)) = sctp_endpoint.handle( + msg.now, + msg.transport.peer_addr, + Some(msg.transport.local_addr.ip()), + msg.transport.ecn, + dtls_message.freeze(), //TODO: switch API Bytes to BytesMut + ) { match event { DatagramEvent::NewAssociation(conn) => { - let sctp_associations = server_states.get_mut_sctp_associations(); sctp_associations.insert(ch, conn); } DatagramEvent::AssociationEvent(event) => { @@ -115,7 +103,6 @@ impl InboundHandler for SctpInbound { { let mut endpoint_events: Vec<(AssociationHandle, EndpointEvent)> = vec![]; - let sctp_associations = server_states.get_mut_sctp_associations(); for (ch, conn) in sctp_associations.iter_mut() { for (event_ch, conn_events) in sctp_events.iter_mut() { if ch == event_ch { @@ -131,7 +118,7 @@ impl InboundHandler for SctpInbound { let mut stream = conn.stream(id)?; while let Some(chunks) = stream.read_sctp()? { let n = chunks.read(&mut self.internal_buffer)?; - messages.push(DataChannelMessage { + messages.push(SctpMessage::Inbound(DataChannelMessage { association_handle: ch.0, stream_id: id, data_message_type: to_data_message_type(chunks.ppi), @@ -139,7 +126,7 @@ impl InboundHandler for SctpInbound { seq_num: chunks.ssn, }, payload: BytesMut::from(&self.internal_buffer[0..n]), - }); + })); } } } @@ -149,11 +136,12 @@ impl InboundHandler for SctpInbound { } while let Some(x) = conn.poll_transmit(msg.now) { - self.transmits.extend(split_transmit(x)); + for transmit in split_transmit(x) { + messages.push(SctpMessage::Outbound(transmit)); + } } } - let mut sctp_endpoint = self.sctp_endpoint.borrow_mut(); for (ch, event) in endpoint_events { sctp_endpoint.handle_event(ch, event); // handle drain event sctp_associations.remove(&ch); @@ -165,15 +153,36 @@ impl InboundHandler for SctpInbound { match try_read() { Ok(messages) => { for message in messages { - debug!( - "recv sctp data channel message {:?}", - msg.transport.peer_addr - ); - ctx.fire_read(TaggedMessageEvent { - now: msg.now, - transport: msg.transport, - message: MessageEvent::Dtls(DTLSMessageEvent::Sctp(message)), - }) + match message { + SctpMessage::Inbound(message) => { + debug!( + "recv sctp data channel message {:?}", + msg.transport.peer_addr + ); + ctx.fire_read(TaggedMessageEvent { + now: msg.now, + transport: msg.transport, + message: MessageEvent::Dtls(DTLSMessageEvent::Sctp(message)), + }) + } + SctpMessage::Outbound(transmit) => { + if let Payload::RawEncode(raw_data) = transmit.payload { + for raw in raw_data { + ctx.fire_write(TaggedMessageEvent { + now: transmit.now, + transport: TransportContext { + local_addr: self.local_addr, + peer_addr: transmit.remote, + ecn: transmit.ecn, + }, + message: MessageEvent::Dtls(DTLSMessageEvent::Raw( + BytesMut::from(&raw[..]), + )), + }); + } + } + } + } } } Err(err) => { @@ -181,7 +190,6 @@ impl InboundHandler for SctpInbound { ctx.fire_read_exception(Box::new(err)) } }; - handle_outgoing(ctx, &mut self.transmits, msg.transport.local_addr); } else { // Bypass debug!("bypass sctp read {:?}", msg.transport.peer_addr); @@ -190,48 +198,82 @@ impl InboundHandler for SctpInbound { } fn handle_timeout(&mut self, ctx: &InboundContext, now: Instant) { - let mut try_timeout = || -> Result<()> { - let mut endpoint_events: Vec<(AssociationHandle, EndpointEvent)> = vec![]; + let try_timeout = || -> Result> { + let mut transmits = vec![]; let mut server_states = self.server_states.borrow_mut(); - let sctp_associations = server_states.get_mut_sctp_associations(); - for (ch, conn) in sctp_associations.iter_mut() { - conn.handle_timeout(now); + for session in server_states.get_mut_sessions().values_mut() { + for endpoint in session.get_mut_endpoints().values_mut() { + for transport in endpoint.get_mut_transports().values_mut() { + let (sctp_endpoint, sctp_associations) = + transport.get_mut_sctp_endpoint_associations(); - while let Some(event) = conn.poll_endpoint_event() { - endpoint_events.push((*ch, event)); - } + let mut endpoint_events: Vec<(AssociationHandle, EndpointEvent)> = vec![]; + for (ch, conn) in sctp_associations.iter_mut() { + conn.handle_timeout(now); - while let Some(x) = conn.poll_transmit(now) { - self.transmits.extend(split_transmit(x)); - } - } + while let Some(event) = conn.poll_endpoint_event() { + endpoint_events.push((*ch, event)); + } + + while let Some(x) = conn.poll_transmit(now) { + transmits.extend(split_transmit(x)); + } + } - let mut sctp_endpoint = self.sctp_endpoint.borrow_mut(); - for (ch, event) in endpoint_events { - sctp_endpoint.handle_event(ch, event); // handle drain event - sctp_associations.remove(&ch); + for (ch, event) in endpoint_events { + sctp_endpoint.handle_event(ch, event); // handle drain event + sctp_associations.remove(&ch); + } + } + } } - Ok(()) + Ok(transmits) }; - if let Err(err) = try_timeout() { - error!("try_timeout with error {}", err); - ctx.fire_read_exception(Box::new(err)); + match try_timeout() { + Ok(transmits) => { + for transmit in transmits { + if let Payload::RawEncode(raw_data) = transmit.payload { + for raw in raw_data { + ctx.fire_write(TaggedMessageEvent { + now: transmit.now, + transport: TransportContext { + local_addr: self.local_addr, + peer_addr: transmit.remote, + ecn: transmit.ecn, + }, + message: MessageEvent::Dtls(DTLSMessageEvent::Raw(BytesMut::from( + &raw[..], + ))), + }); + } + } + } + } + Err(err) => { + error!("try_timeout with error {}", err); + ctx.fire_read_exception(Box::new(err)); + } } - handle_outgoing(ctx, &mut self.transmits, self.local_addr); ctx.fire_handle_timeout(now); } fn poll_timeout(&mut self, ctx: &InboundContext, eto: &mut Instant) { { - let mut server_states = self.server_states.borrow_mut(); - let sctp_associations = server_states.get_mut_sctp_associations(); - for (_, conn) in sctp_associations.iter_mut() { - if let Some(timeout) = conn.poll_timeout() { - if timeout < *eto { - *eto = timeout; + let server_states = self.server_states.borrow(); + for session in server_states.get_sessions().values() { + for endpoint in session.get_endpoints().values() { + for transport in endpoint.get_transports().values() { + let sctp_associations = transport.get_sctp_associations(); + for conn in sctp_associations.values() { + if let Some(timeout) = conn.poll_timeout() { + if timeout < *eto { + *eto = timeout; + } + } + } } } } @@ -250,7 +292,10 @@ impl OutboundHandler for SctpOutbound { "send sctp data channel message {:?}", msg.transport.peer_addr ); - let mut try_write = || -> Result<()> { + let four_tuple = (&msg.transport).into(); + + let try_write = || -> Result> { + let mut transmits = vec![]; let mut server_states = self.server_states.borrow_mut(); let max_message_size = { server_states @@ -263,7 +308,8 @@ impl OutboundHandler for SctpOutbound { return Err(Error::ErrOutboundPacketTooLarge); } - let sctp_associations = server_states.get_mut_sctp_associations(); + let transport = server_states.get_mut_transport(&four_tuple)?; + let sctp_associations = transport.get_mut_sctp_associations(); if let Some(conn) = sctp_associations.get_mut(&AssociationHandle(message.association_handle)) { @@ -292,38 +338,45 @@ impl OutboundHandler for SctpOutbound { )?; while let Some(x) = conn.poll_transmit(msg.now) { - self.transmits.extend(split_transmit(x)); + transmits.extend(split_transmit(x)); } } } else { return Err(Error::ErrAssociationNotExisted); } - Ok(()) + Ok(transmits) }; - if let Err(err) = try_write() { - error!("try_write with error {}", err); - ctx.fire_write_exception(Box::new(err)); + match try_write() { + Ok(transmits) => { + for transmit in transmits { + if let Payload::RawEncode(raw_data) = transmit.payload { + for raw in raw_data { + ctx.fire_write(TaggedMessageEvent { + now: transmit.now, + transport: TransportContext { + local_addr: self.local_addr, + peer_addr: transmit.remote, + ecn: transmit.ecn, + }, + message: MessageEvent::Dtls(DTLSMessageEvent::Raw( + BytesMut::from(&raw[..]), + )), + }); + } + } + } + } + Err(err) => { + error!("try_write with error {}", err); + ctx.fire_write_exception(Box::new(err)); + } } - handle_outgoing(ctx, &mut self.transmits, msg.transport.local_addr); } else { // Bypass debug!("Bypass sctp write {:?}", msg.transport.peer_addr); ctx.fire_write(msg); } } - - fn close(&mut self, ctx: &OutboundContext) { - { - let mut server_states = self.server_states.borrow_mut(); - let sctp_associations = server_states.get_mut_sctp_associations(); - for (_, conn) in sctp_associations.iter_mut() { - let _ = conn.close(); - } - } - handle_outgoing(ctx, &mut self.transmits, self.local_addr); - - ctx.fire_close(); - } } impl Handler for SctpHandler { @@ -346,28 +399,6 @@ impl Handler for SctpHandler { } } -fn handle_outgoing( - ctx: &OutboundContext, - transmits: &mut VecDeque, - local_addr: SocketAddr, -) { - while let Some(transmit) = transmits.pop_front() { - if let Payload::RawEncode(raw_data) = transmit.payload { - for raw in raw_data { - ctx.fire_write(TaggedMessageEvent { - now: transmit.now, - transport: TransportContext { - local_addr, - peer_addr: transmit.remote, - ecn: transmit.ecn, - }, - message: MessageEvent::Dtls(DTLSMessageEvent::Raw(BytesMut::from(&raw[..]))), - }); - } - } - } -} - fn split_transmit(transmit: Transmit) -> Vec { let mut transmits = Vec::new(); if let Payload::RawEncode(contents) = transmit.payload { diff --git a/src/server/states.rs b/src/server/states.rs index 27d2dd2..0d5cf2c 100644 --- a/src/server/states.rs +++ b/src/server/states.rs @@ -7,7 +7,6 @@ use crate::endpoint::{ use crate::server::config::ServerConfig; use crate::session::{config::SessionConfig, Session}; use crate::types::{EndpointId, FourTuple, SessionId, UserName}; -use sctp::{Association, AssociationHandle}; use shared::error::{Error, Result}; use std::collections::hash_map::Entry; use std::collections::HashMap; @@ -24,7 +23,6 @@ pub struct ServerStates { //TODO: add idle timeout cleanup logic to remove idle endpoint and candidates candidates: HashMap>, endpoints: HashMap, - sctp_associations: HashMap, } impl ServerStates { @@ -45,8 +43,6 @@ impl ServerStates { candidates: HashMap::new(), endpoints: HashMap::new(), - - sctp_associations: HashMap::new(), }) } @@ -136,16 +132,6 @@ impl ServerStates { self.local_addr } - pub(crate) fn get_sctp_associations(&self) -> &HashMap { - &self.sctp_associations - } - - pub(crate) fn get_mut_sctp_associations( - &mut self, - ) -> &mut HashMap { - &mut self.sctp_associations - } - pub(crate) fn create_or_get_mut_session(&mut self, session_id: SessionId) -> &mut Session { if let Entry::Vacant(e) = self.sessions.entry(session_id) { let session = Session::new(