diff --git a/src/handlers/gateway/mod.rs b/src/handlers/gateway/mod.rs index 56b8e6d..4982503 100644 --- a/src/handlers/gateway/mod.rs +++ b/src/handlers/gateway/mod.rs @@ -1,11 +1,22 @@ use crate::messages::{MessageEvent, STUNMessageEvent, TaggedMessageEvent}; +use crate::server::endpoint::transport::Transport; +use crate::server::endpoint::{candidate::Candidate, Endpoint}; use crate::server::states::ServerStates; -use log::warn; +use log::{debug, warn}; use retty::channel::{Handler, InboundContext, InboundHandler, OutboundContext, OutboundHandler}; use retty::transport::TransportContext; -use shared::error::Result; +use shared::error::{Error, Result}; use std::rc::Rc; use std::time::Instant; +use stun::attributes::{ + ATTR_ICE_CONTROLLED, ATTR_ICE_CONTROLLING, ATTR_NETWORK_COST, ATTR_PRIORITY, ATTR_USERNAME, + ATTR_USE_CANDIDATE, +}; +use stun::fingerprint::FINGERPRINT; +use stun::integrity::MessageIntegrity; +use stun::message::{Setter, TransactionId, BINDING_SUCCESS}; +use stun::textattrs::TextAttribute; +use stun::xoraddr::XorMappedAddress; struct GatewayInbound { server_states: Rc, @@ -35,17 +46,17 @@ impl InboundHandler for GatewayInbound { type Rout = Self::Rin; fn read(&mut self, ctx: &InboundContext, msg: Self::Rin) { - let try_read = || -> Result<()> { + let try_read = || -> Result> { match msg.message { - MessageEvent::STUN(STUNMessageEvent::STUN(message)) => { - self.handle_stun_message(ctx, msg.now, msg.transport, message) - } + MessageEvent::STUN(STUNMessageEvent::STUN(message)) => Ok(Some( + self.handle_stun_message(msg.now, msg.transport, message)?, + )), _ => { warn!( "drop unsupported message {:?} from {}", msg.message, msg.transport.peer_addr ); - Ok(()) + Ok(None) } } }; @@ -92,11 +103,163 @@ impl Handler for GatewayHandler { impl GatewayInbound { fn handle_stun_message( &mut self, - _ctx: &InboundContext, - _now: Instant, - _transport: TransportContext, - _request: stun::message::Message, - ) -> Result<()> { - Ok(()) + now: Instant, + transport_context: TransportContext, + mut request: stun::message::Message, + ) -> Result { + let candidate = match self.check_stun_message(&mut request)? { + Some(candidate) => candidate, + None => { + return self.send_server_reflective_address( + now, + transport_context, + request.transaction_id, + ); + } + }; + + let (_is_new_endpoint, _endpoint, _transport) = + self.add_endpoint(&request, &candidate, &transport_context)?; + + let mut response = stun::message::Message::new(); + response.build(&[ + Box::new(BINDING_SUCCESS), + Box::new(request.transaction_id), + Box::new(XorMappedAddress { + ip: transport_context.peer_addr.ip(), + port: transport_context.peer_addr.port(), + }), + ])?; + let integrity = MessageIntegrity::new_short_term_integrity( + candidate.get_local_parameters().password.clone(), + ); + integrity.add_to(&mut response)?; + FINGERPRINT.add_to(&mut response)?; + + debug!( + "handle_stun_message response type {} with ip {} and port {} sent", + response.typ, + transport_context.peer_addr.ip(), + transport_context.peer_addr.port() + ); + + Ok(TaggedMessageEvent { + now, + transport: transport_context, + message: MessageEvent::STUN(STUNMessageEvent::STUN(response)), + }) + } + + fn check_stun_message( + &self, + request: &mut stun::message::Message, + ) -> Result>> { + match TextAttribute::get_from_as(request, ATTR_USERNAME) { + Ok(username) => { + if !request.contains(ATTR_PRIORITY) { + return Err(Error::Other( + "invalid STUN message without ATTR_PRIORITY".to_string(), + )); + } + + if request.contains(ATTR_ICE_CONTROLLING) { + if request.contains(ATTR_ICE_CONTROLLED) { + return Err(Error::Other("invalid STUN message with both ATTR_ICE_CONTROLLING and ATTR_ICE_CONTROLLED".to_string())); + } + } else if request.contains(ATTR_ICE_CONTROLLED) { + if request.contains(ATTR_USE_CANDIDATE) { + return Err(Error::Other("invalid STUN message with both ATTR_USE_CANDIDATE and ATTR_ICE_CONTROLLED".to_string())); + } + } else { + return Err(Error::Other( + "invalid STUN message without ATTR_ICE_CONTROLLING or ATTR_ICE_CONTROLLED" + .to_string(), + )); + } + + if let Some(candidate) = self.server_states.find_candidate(&username.text) { + let password = candidate.get_local_parameters().password.clone(); + let integrity = MessageIntegrity::new_short_term_integrity(password); + integrity.check(request)?; + Ok(Some(candidate)) + } else { + Err(Error::Other("username not found".to_string())) + } + } + Err(_) => { + if request.contains(ATTR_ICE_CONTROLLED) + || request.contains(ATTR_ICE_CONTROLLING) + || request.contains(ATTR_NETWORK_COST) + || request.contains(ATTR_PRIORITY) + || request.contains(ATTR_USE_CANDIDATE) + { + Err(Error::Other("unexpected attribute".to_string())) + } else { + Ok(None) + } + } + } + } + + fn send_server_reflective_address( + &mut self, + now: Instant, + transport_context: TransportContext, + transaction_id: TransactionId, + ) -> Result { + let mut response = stun::message::Message::new(); + response.build(&[ + Box::new(BINDING_SUCCESS), + Box::new(transaction_id), + Box::new(XorMappedAddress { + ip: transport_context.peer_addr.ip(), + port: transport_context.peer_addr.port(), + }), + ])?; + + debug!( + "send_server_reflective_address response type {} sent", + response.typ + ); + + Ok(TaggedMessageEvent { + now, + transport: transport_context, + message: MessageEvent::STUN(STUNMessageEvent::STUN(response)), + }) + } + + #[allow(clippy::type_complexity)] + fn add_endpoint( + &mut self, + request: &stun::message::Message, + candidate: &Candidate, + transport_context: &TransportContext, + ) -> Result<(bool, Option>, Option>)> { + let mut is_new_endpoint = false; + + let session_id = candidate.session_id(); + let session = self + .server_states + .get_session(&session_id) + .ok_or(Error::Other(format!("session {} not found", session_id)))?; + + let endpoint_id = candidate.endpoint_id(); + let endpoint = session.get_endpoint(&endpoint_id); + let transport = if let Some(endpoint) = &endpoint { + let four_tuple = transport_context.into(); + endpoint.get_transport(&four_tuple) + } else { + is_new_endpoint = true; + None + }; + + if !request.contains(ATTR_USE_CANDIDATE) || transport.is_some() { + return Ok((is_new_endpoint, endpoint, transport)); + } + + //todo:session.add_endpoint(candidate, transport_context); + + Ok((is_new_endpoint, None, None)) } } diff --git a/src/server/endpoint/mod.rs b/src/server/endpoint/mod.rs index ea63ec5..03f1813 100644 --- a/src/server/endpoint/mod.rs +++ b/src/server/endpoint/mod.rs @@ -1,13 +1,18 @@ pub mod candidate; +pub mod transport; +use crate::server::endpoint::transport::Transport; use crate::server::session::Session; -use crate::types::EndpointId; -use std::rc::Weak; +use crate::types::{EndpointId, FourTuple}; +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::{Rc, Weak}; -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone)] pub struct Endpoint { session: Weak, endpoint_id: EndpointId, + transports: RefCell>>, } impl Endpoint { @@ -15,6 +20,7 @@ impl Endpoint { Self { session, endpoint_id, + transports: RefCell::new(HashMap::new()), } } @@ -25,4 +31,13 @@ impl Endpoint { pub fn endpoint_id(&self) -> EndpointId { self.endpoint_id } + + pub fn add_transport(&self, transport: Rc) { + let mut transports = self.transports.borrow_mut(); + transports.insert(*transport.four_tuple(), transport); + } + + pub fn get_transport(&self, four_tuple: &FourTuple) -> Option> { + self.transports.borrow().get(four_tuple).cloned() + } } diff --git a/src/server/endpoint/transport.rs b/src/server/endpoint/transport.rs new file mode 100644 index 0000000..547024c --- /dev/null +++ b/src/server/endpoint/transport.rs @@ -0,0 +1,37 @@ +use crate::server::endpoint::candidate::Candidate; +use crate::server::endpoint::Endpoint; +use crate::types::FourTuple; +use std::rc::{Rc, Weak}; + +#[derive(Debug, Clone)] +pub struct Transport { + four_tuple: FourTuple, + endpoint: Weak, + candidate: Rc, +} + +impl Transport { + pub(crate) fn new( + four_tuple: FourTuple, + endpoint: Weak, + candidate: Rc, + ) -> Self { + Self { + four_tuple, + endpoint, + candidate, + } + } + + pub(crate) fn four_tuple(&self) -> &FourTuple { + &self.four_tuple + } + + pub(crate) fn endpoint(&self) -> &Weak { + &self.endpoint + } + + pub(crate) fn candidate(&self) -> &Rc { + &self.candidate + } +} diff --git a/src/server/session/mod.rs b/src/server/session/mod.rs index 2932782..c412281 100644 --- a/src/server/session/mod.rs +++ b/src/server/session/mod.rs @@ -1,3 +1,4 @@ +//use retty::transport::TransportContext; use sdp::description::session::Origin; use sdp::util::ConnectionRole; use sdp::SessionDescription; @@ -10,7 +11,8 @@ use std::rc::Rc; pub mod description; use crate::server::certificate::RTCCertificate; -use crate::server::endpoint::candidate::{DTLSRole, RTCIceParameters}; +use crate::server::endpoint::candidate::{/*Candidate,*/ DTLSRole, RTCIceParameters}; +//use crate::server::endpoint::transport::Transport; use crate::server::endpoint::Endpoint; use crate::server::session::description::rtp_codec::RTPCodecType; use crate::server::session::description::rtp_transceiver::RTCRtpTransceiver; @@ -49,7 +51,37 @@ impl Session { self.session_id } - pub fn create_pending_answer( + /* + pub(crate) fn add_endpoint( + &mut self, + candidate: &Rc, + transport_context: &TransportContext, + ) -> Result<(bool, Rc, Rc)> { + let endpoint_id = candidate.endpoint_id(); + let endpoint = self.get_endpoint(&endpoint_id); + if let Some(endpoint) = endpoint { + let four_tuple = transport_context.into(); + if let Some(transport) = endpoint.get_transport(&four_tuple) { + Ok((true, endpoint, transport)) + } else { + let transport = Rc::new(Transport::new( + four_tuple, + Rc::downgrade(&endpoint), + Rc::clone(&candidate), + )); + endpoint.add_transport(Rc::clone(&transport)); + Ok((true, endpoint, transport)) + } + } else { + Ok((false, endpoint, transport)) + } + }*/ + + pub(crate) fn get_endpoint(&self, endpoint_id: &EndpointId) -> Option> { + self.endpoints.borrow().get(endpoint_id).cloned() + } + + pub(crate) fn create_pending_answer( &self, _endpoint_id: EndpointId, remote_description: &RTCSessionDescription, diff --git a/src/server/states.rs b/src/server/states.rs index 4d97410..d6d6d05 100644 --- a/src/server/states.rs +++ b/src/server/states.rs @@ -18,6 +18,7 @@ pub struct ServerStates { sessions: RefCell>>, // Thread-local map + //TODO: add idle timeout cleanup logic to remove idle endpoint and candidates endpoints: RefCell>>, candidates: RefCell>>, } @@ -61,6 +62,10 @@ impl ServerStates { } } + pub(crate) fn get_session(&self, session_id: &SessionId) -> Option> { + self.sessions.borrow().get(session_id).cloned() + } + // set pending offer and return answer pub fn accept_pending_offer( &self, diff --git a/src/types.rs b/src/types.rs index b4e7cc9..6173949 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,6 +3,7 @@ use std::net::SocketAddr; pub type SessionId = u64; pub type EndpointId = u64; + pub type UserName = String; #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]