From e01be38d6203c3c69e0b7e043703fde17ef649c7 Mon Sep 17 00:00:00 2001 From: yngrtc Date: Thu, 7 Mar 2024 16:56:35 -0800 Subject: [PATCH] remove endpoint/session when dtls Error::ErrAlertFatalOrClose or idle time --- examples/sync_chat.rs | 4 ++- examples/sync_signal/mod.rs | 7 +++-- src/endpoint/mod.rs | 4 +-- src/endpoint/transport.rs | 11 ++++++++ src/handler/demuxer.rs | 2 +- src/handler/dtls.rs | 7 ++++- src/handler/exception.rs | 2 +- src/handler/gateway.rs | 55 +++++++++++++++++++++++++++++++++++-- src/handler/stun.rs | 2 +- src/server/config.rs | 18 ++++-------- src/server/states.rs | 42 ++++++++++++++++++++++++++-- src/session/mod.rs | 4 +++ 12 files changed, 131 insertions(+), 27 deletions(-) diff --git a/examples/sync_chat.rs b/examples/sync_chat.rs index 7341c14..a5570db 100644 --- a/examples/sync_chat.rs +++ b/examples/sync_chat.rs @@ -8,6 +8,7 @@ use std::net::{IpAddr, UdpSocket}; use std::str::FromStr; use std::sync::mpsc::{self}; use std::sync::Arc; +use std::time::Duration; use wg::WaitGroup; mod sync_signal; @@ -114,7 +115,8 @@ pub fn main() -> anyhow::Result<()> { ServerConfig::new(certificates) .with_dtls_handshake_config(dtls_handshake_config) .with_sctp_endpoint_config(sctp_endpoint_config) - .with_sctp_server_config(sctp_server_config), + .with_sctp_server_config(sctp_server_config) + .with_idle_timeout(Duration::from_secs(30)), ); let wait_group = WaitGroup::new(); diff --git a/examples/sync_signal/mod.rs b/examples/sync_signal/mod.rs index 3c98b72..efac884 100644 --- a/examples/sync_signal/mod.rs +++ b/examples/sync_signal/mod.rs @@ -107,6 +107,7 @@ pub fn sync_run( let mut buf = vec![0; 2000]; + pipeline.transport_active(); loop { match stop_rx.try_recv() { Ok(_) => break, @@ -150,6 +151,7 @@ pub fn sync_run( // Drive time forward in all clients. pipeline.handle_timeout(Instant::now()); } + pipeline.transport_inactive(); println!( "media server on {} is gracefully down", @@ -161,10 +163,9 @@ pub fn sync_run( fn write_socket_output( socket: &UdpSocket, pipeline: &Rc>, -) -> anyhow::Result<()>{ +) -> anyhow::Result<()> { while let Some(transmit) = pipeline.poll_transmit() { - socket - .send_to(&transmit.message, transmit.transport.peer_addr)?; + socket.send_to(&transmit.message, transmit.transport.peer_addr)?; } Ok(()) diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index e8d9ed9..b7c295b 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -46,8 +46,8 @@ impl Endpoint { self.transports.insert(*transport.four_tuple(), transport); } - pub(crate) fn remove_transport(&mut self, four_tuple: &FourTuple) { - self.transports.remove(four_tuple); + pub(crate) fn remove_transport(&mut self, four_tuple: &FourTuple) -> Option { + self.transports.remove(four_tuple) } pub(crate) fn has_transport(&self, four_tuple: &FourTuple) -> bool { diff --git a/src/endpoint/transport.rs b/src/endpoint/transport.rs index b0f2f8a..b7b747c 100644 --- a/src/endpoint/transport.rs +++ b/src/endpoint/transport.rs @@ -5,9 +5,11 @@ use srtp::context::Context; use std::collections::HashMap; use std::rc::Rc; use std::sync::Arc; +use std::time::Instant; pub(crate) struct Transport { four_tuple: FourTuple, + last_activity: Instant, // ICE candidate: Rc, @@ -38,6 +40,7 @@ impl Transport { ) -> Self { Self { four_tuple, + last_activity: Instant::now(), candidate, @@ -129,4 +132,12 @@ impl Transport { pub(crate) fn is_local_srtp_context_ready(&self) -> bool { self.local_srtp_context.is_some() } + + pub(crate) fn keep_alive(&mut self) { + self.last_activity = Instant::now(); + } + + pub(crate) fn last_activity(&self) -> Instant { + self.last_activity + } } diff --git a/src/handler/demuxer.rs b/src/handler/demuxer.rs index b15e01a..2da5196 100644 --- a/src/handler/demuxer.rs +++ b/src/handler/demuxer.rs @@ -45,7 +45,7 @@ pub struct DemuxerHandler; impl DemuxerHandler { pub fn new() -> Self { - DemuxerHandler::default() + DemuxerHandler } } diff --git a/src/handler/dtls.rs b/src/handler/dtls.rs index b5af71f..60890b1 100644 --- a/src/handler/dtls.rs +++ b/src/handler/dtls.rs @@ -132,7 +132,12 @@ impl Handler for DtlsHandler { } Err(err) => { error!("try_read with error {}", err); - ctx.fire_exception(Box::new(err)) + if err == Error::ErrAlertFatalOrClose { + let mut server_states = self.server_states.borrow_mut(); + server_states.remove_transport(four_tuple); + } else { + ctx.fire_exception(Box::new(err)) + } } }; } else { diff --git a/src/handler/exception.rs b/src/handler/exception.rs index 164ed08..faf576f 100644 --- a/src/handler/exception.rs +++ b/src/handler/exception.rs @@ -9,7 +9,7 @@ pub struct ExceptionHandler; impl ExceptionHandler { pub fn new() -> Self { - ExceptionHandler::default() + ExceptionHandler } } diff --git a/src/handler/gateway.rs b/src/handler/gateway.rs index 0f19ea3..bdd3ff5 100644 --- a/src/handler/gateway.rs +++ b/src/handler/gateway.rs @@ -15,8 +15,10 @@ use retty::transport::TransportContext; use shared::error::{Error, Result}; use std::cell::RefCell; use std::collections::VecDeque; +use std::ops::{Add, Sub}; use std::rc::Rc; use std::time::Instant; +use std::time::Duration; use stun::attributes::{ ATTR_ICE_CONTROLLED, ATTR_ICE_CONTROLLING, ATTR_NETWORK_COST, ATTR_PRIORITY, ATTR_USERNAME, ATTR_USE_CANDIDATE, @@ -31,13 +33,19 @@ use stun::xoraddr::XorMappedAddress; pub struct GatewayHandler { server_states: Rc>, transmits: VecDeque, + next_timeout: Instant, + idle_timeout: Duration, } impl GatewayHandler { pub fn new(server_states: Rc>) -> Self { + let idle_timeout = server_states.borrow().server_config().idle_timeout; + GatewayHandler { server_states, transmits: VecDeque::new(), + next_timeout: Instant::now().add(idle_timeout), + idle_timeout, } } } @@ -52,6 +60,23 @@ impl Handler for GatewayHandler { "GatewayHandler" } + fn transport_inactive(&mut self, _ctx: &Context) { + let server_states = self.server_states.borrow(); + let sessions = server_states.get_sessions(); + let mut endpoint_count = 0; + for session in sessions.values() { + endpoint_count += session.get_endpoints().len(); + } + info!( + "Still Active Sessions {}, Endpoints {}/{}, Candidates {} on {}", + sessions.len(), + endpoint_count, + server_states.get_endpoints().len(), + server_states.get_candidates().len(), + server_states.local_addr() + ); + } + fn handle_read( &mut self, ctx: &Context, @@ -115,9 +140,34 @@ impl Handler for GatewayHandler { fn handle_timeout( &mut self, _ctx: &Context, - _now: Instant, + now: Instant, ) { // terminate timeout here, no more ctx.fire_handle_timeout(now); + if self.next_timeout <= now { + let mut four_tuples = vec![]; + let mut server_states = self.server_states.borrow_mut(); + for session in server_states.get_mut_sessions().values_mut() { + for endpoint in session.get_mut_endpoints().values_mut() { + for transport in endpoint.get_mut_transports().values_mut() { + if transport.last_activity() <= now.sub(self.idle_timeout) { + four_tuples.push(*transport.four_tuple()); + } + } + } + } + for four_tuple in four_tuples { + server_states.remove_transport(four_tuple); + } + + self.next_timeout = self.next_timeout.add(self.idle_timeout); + } + } + + fn poll_timeout(&mut self, ctx: &Context, eto: &mut Instant) { + if self.next_timeout < *eto { + *eto = self.next_timeout; + } + ctx.fire_poll_timeout(eto); } fn poll_write( @@ -127,7 +177,6 @@ impl Handler for GatewayHandler { if let Some(msg) = ctx.fire_poll_write() { self.transmits.push_back(msg); } - self.transmits.pop_front() } } @@ -387,6 +436,7 @@ impl GatewayHandler { rtp_packet: rtp::packet::Packet, ) -> Result> { debug!("handle_rtp_message {}", transport_context.peer_addr); + server_states.get_mut_transport(&(&transport_context).into())?.keep_alive(); //TODO: Selective Forwarding RTP Packets let peers = @@ -411,6 +461,7 @@ impl GatewayHandler { rtcp_packets: Vec>, ) -> Result> { debug!("handle_rtcp_message {}", transport_context.peer_addr); + server_states.get_mut_transport(&(&transport_context).into())?.keep_alive(); //TODO: Selective Forwarding RTCP Packets let peers = diff --git a/src/handler/stun.rs b/src/handler/stun.rs index f3e776b..d1a8fc3 100644 --- a/src/handler/stun.rs +++ b/src/handler/stun.rs @@ -11,7 +11,7 @@ pub struct StunHandler; impl StunHandler { pub fn new() -> Self { - StunHandler::default() + StunHandler } } diff --git a/src/server/config.rs b/src/server/config.rs index ce3b5a7..b1ad4ff 100644 --- a/src/server/config.rs +++ b/src/server/config.rs @@ -10,8 +10,7 @@ pub struct ServerConfig { pub(crate) sctp_endpoint_config: Arc, pub(crate) sctp_server_config: Arc, pub(crate) media_config: MediaConfig, - pub(crate) endpoint_idle_timeout: Duration, - pub(crate) candidate_idle_timeout: Duration, + pub(crate) idle_timeout: Duration, } impl ServerConfig { @@ -23,8 +22,7 @@ impl ServerConfig { sctp_endpoint_config: Arc::new(sctp::EndpointConfig::default()), sctp_server_config: Arc::new(sctp::ServerConfig::default()), dtls_handshake_config: Arc::new(dtls::config::HandshakeConfig::default()), - endpoint_idle_timeout: Duration::from_secs(30), - candidate_idle_timeout: Duration::from_secs(30), + idle_timeout: Duration::from_secs(30), } } @@ -58,15 +56,9 @@ impl ServerConfig { self } - /// build with endpoint idle timeout - pub fn with_endpoint_idle_timeout(mut self, endpoint_idle_timeout: Duration) -> Self { - self.endpoint_idle_timeout = endpoint_idle_timeout; - self - } - - /// build with candidate idle timeout - pub fn with_candidate_idle_timeout(mut self, candidate_idle_timeout: Duration) -> Self { - self.candidate_idle_timeout = candidate_idle_timeout; + /// build with idle timeout + pub fn with_idle_timeout(mut self, idle_timeout: Duration) -> Self { + self.idle_timeout = idle_timeout; self } } diff --git a/src/server/states.rs b/src/server/states.rs index 95f6b69..5b06a64 100644 --- a/src/server/states.rs +++ b/src/server/states.rs @@ -7,7 +7,7 @@ use crate::endpoint::{ use crate::server::config::ServerConfig; use crate::session::{config::SessionConfig, Session}; use crate::types::{EndpointId, FourTuple, SessionId, UserName}; -use log::info; +use log::{debug, info}; use shared::error::{Error, Result}; use std::collections::hash_map::Entry; use std::collections::HashMap; @@ -101,7 +101,7 @@ impl ServerStates { local_conn_cred, offer, answer.clone(), - Instant::now() + self.server_config.candidate_idle_timeout, + Instant::now() + self.server_config.idle_timeout, ))); } @@ -162,6 +162,10 @@ impl ServerStates { self.sessions.get_mut(session_id) } + pub(crate) fn remove_session(&mut self, session_id: &SessionId) -> Option { + self.sessions.remove(session_id) + } + pub(crate) fn add_candidate(&mut self, candidate: Rc) -> Option> { let username = candidate.username(); self.candidates.insert(username, candidate) @@ -175,6 +179,14 @@ impl ServerStates { self.candidates.get(username) } + pub(crate) fn get_candidates(&self) -> &HashMap> { + &self.candidates + } + + pub(crate) fn get_endpoints(&self) -> &HashMap { + &self.endpoints + } + pub(crate) fn add_endpoint( &mut self, four_tuple: FourTuple, @@ -242,4 +254,30 @@ impl ServerStates { Ok(transport) } + + pub(crate) fn remove_transport(&mut self, four_tuple: FourTuple) { + debug!("remove idle transport {:?}", four_tuple); + + let Some((session_id, endpoint_id)) = self.find_endpoint(&four_tuple) else { + return; + }; + let Some(session) = self.get_mut_session(&session_id) else { + return; + }; + let Some(endpoint) = session.get_mut_endpoint(&endpoint_id) else { + return; + }; + + let transport = endpoint.remove_transport(&four_tuple); + if endpoint.get_transports().is_empty() { + session.remove_endpoint(&endpoint_id); + if session.get_endpoints().is_empty() { + self.remove_session(&session_id); + } + self.remove_endpoint(&four_tuple); + } + if let Some(transport) = transport { + self.remove_candidate(&transport.candidate().username()); + } + } } diff --git a/src/session/mod.rs b/src/session/mod.rs index 78bf973..a19ff6f 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -110,6 +110,10 @@ impl Session { self.endpoints.get_mut(endpoint_id) } + pub(crate) fn remove_endpoint(&mut self, endpoint_id: &EndpointId) -> Option { + self.endpoints.remove(endpoint_id) + } + pub(crate) fn has_endpoint(&self, endpoint_id: &EndpointId) -> bool { self.endpoints.contains_key(endpoint_id) }