Skip to content

Commit

Permalink
remove endpoint/session when dtls Error::ErrAlertFatalOrClose or idle…
Browse files Browse the repository at this point in the history
… time
  • Loading branch information
yngrtc committed Mar 8, 2024
1 parent d6c7c0b commit e01be38
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 27 deletions.
4 changes: 3 additions & 1 deletion examples/sync_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::net::{IpAddr, UdpSocket};
use std::str::FromStr;
use std::sync::mpsc::{self};
use std::sync::Arc;
use std::time::Duration;
use wg::WaitGroup;

mod sync_signal;
Expand Down Expand Up @@ -114,7 +115,8 @@ pub fn main() -> anyhow::Result<()> {
ServerConfig::new(certificates)
.with_dtls_handshake_config(dtls_handshake_config)
.with_sctp_endpoint_config(sctp_endpoint_config)
.with_sctp_server_config(sctp_server_config),
.with_sctp_server_config(sctp_server_config)
.with_idle_timeout(Duration::from_secs(30)),
);
let wait_group = WaitGroup::new();

Expand Down
7 changes: 4 additions & 3 deletions examples/sync_signal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ pub fn sync_run(

let mut buf = vec![0; 2000];

pipeline.transport_active();
loop {
match stop_rx.try_recv() {
Ok(_) => break,
Expand Down Expand Up @@ -150,6 +151,7 @@ pub fn sync_run(
// Drive time forward in all clients.
pipeline.handle_timeout(Instant::now());
}
pipeline.transport_inactive();

println!(
"media server on {} is gracefully down",
Expand All @@ -161,10 +163,9 @@ pub fn sync_run(
fn write_socket_output(
socket: &UdpSocket,
pipeline: &Rc<Pipeline<TaggedBytesMut, TaggedBytesMut>>,
) -> anyhow::Result<()>{
) -> anyhow::Result<()> {
while let Some(transmit) = pipeline.poll_transmit() {
socket
.send_to(&transmit.message, transmit.transport.peer_addr)?;
socket.send_to(&transmit.message, transmit.transport.peer_addr)?;
}

Ok(())
Expand Down
4 changes: 2 additions & 2 deletions src/endpoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ impl Endpoint {
self.transports.insert(*transport.four_tuple(), transport);
}

pub(crate) fn remove_transport(&mut self, four_tuple: &FourTuple) {
self.transports.remove(four_tuple);
pub(crate) fn remove_transport(&mut self, four_tuple: &FourTuple) -> Option<Transport> {
self.transports.remove(four_tuple)
}

pub(crate) fn has_transport(&self, four_tuple: &FourTuple) -> bool {
Expand Down
11 changes: 11 additions & 0 deletions src/endpoint/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ use srtp::context::Context;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

pub(crate) struct Transport {
four_tuple: FourTuple,
last_activity: Instant,

// ICE
candidate: Rc<Candidate>,
Expand Down Expand Up @@ -38,6 +40,7 @@ impl Transport {
) -> Self {
Self {
four_tuple,
last_activity: Instant::now(),

candidate,

Expand Down Expand Up @@ -129,4 +132,12 @@ impl Transport {
pub(crate) fn is_local_srtp_context_ready(&self) -> bool {
self.local_srtp_context.is_some()
}

pub(crate) fn keep_alive(&mut self) {
self.last_activity = Instant::now();
}

pub(crate) fn last_activity(&self) -> Instant {
self.last_activity
}
}
2 changes: 1 addition & 1 deletion src/handler/demuxer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub struct DemuxerHandler;

impl DemuxerHandler {
pub fn new() -> Self {
DemuxerHandler::default()
DemuxerHandler
}
}

Expand Down
7 changes: 6 additions & 1 deletion src/handler/dtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,12 @@ impl Handler for DtlsHandler {
}
Err(err) => {
error!("try_read with error {}", err);
ctx.fire_exception(Box::new(err))
if err == Error::ErrAlertFatalOrClose {
let mut server_states = self.server_states.borrow_mut();
server_states.remove_transport(four_tuple);
} else {
ctx.fire_exception(Box::new(err))
}
}
};
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/handler/exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub struct ExceptionHandler;

impl ExceptionHandler {
pub fn new() -> Self {
ExceptionHandler::default()
ExceptionHandler
}
}

Expand Down
55 changes: 53 additions & 2 deletions src/handler/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ use retty::transport::TransportContext;
use shared::error::{Error, Result};
use std::cell::RefCell;
use std::collections::VecDeque;
use std::ops::{Add, Sub};
use std::rc::Rc;
use std::time::Instant;
use std::time::Duration;
use stun::attributes::{
ATTR_ICE_CONTROLLED, ATTR_ICE_CONTROLLING, ATTR_NETWORK_COST, ATTR_PRIORITY, ATTR_USERNAME,
ATTR_USE_CANDIDATE,
Expand All @@ -31,13 +33,19 @@ use stun::xoraddr::XorMappedAddress;
pub struct GatewayHandler {
server_states: Rc<RefCell<ServerStates>>,
transmits: VecDeque<TaggedMessageEvent>,
next_timeout: Instant,
idle_timeout: Duration,
}

impl GatewayHandler {
pub fn new(server_states: Rc<RefCell<ServerStates>>) -> Self {
let idle_timeout = server_states.borrow().server_config().idle_timeout;

GatewayHandler {
server_states,
transmits: VecDeque::new(),
next_timeout: Instant::now().add(idle_timeout),
idle_timeout,
}
}
}
Expand All @@ -52,6 +60,23 @@ impl Handler for GatewayHandler {
"GatewayHandler"
}

fn transport_inactive(&mut self, _ctx: &Context<Self::Rin, Self::Rout, Self::Win, Self::Wout>) {
let server_states = self.server_states.borrow();
let sessions = server_states.get_sessions();
let mut endpoint_count = 0;
for session in sessions.values() {
endpoint_count += session.get_endpoints().len();
}
info!(
"Still Active Sessions {}, Endpoints {}/{}, Candidates {} on {}",
sessions.len(),
endpoint_count,
server_states.get_endpoints().len(),
server_states.get_candidates().len(),
server_states.local_addr()
);
}

fn handle_read(
&mut self,
ctx: &Context<Self::Rin, Self::Rout, Self::Win, Self::Wout>,
Expand Down Expand Up @@ -115,9 +140,34 @@ impl Handler for GatewayHandler {
fn handle_timeout(
&mut self,
_ctx: &Context<Self::Rin, Self::Rout, Self::Win, Self::Wout>,
_now: Instant,
now: Instant,
) {
// terminate timeout here, no more ctx.fire_handle_timeout(now);
if self.next_timeout <= now {
let mut four_tuples = vec![];
let mut server_states = self.server_states.borrow_mut();
for session in server_states.get_mut_sessions().values_mut() {
for endpoint in session.get_mut_endpoints().values_mut() {
for transport in endpoint.get_mut_transports().values_mut() {
if transport.last_activity() <= now.sub(self.idle_timeout) {
four_tuples.push(*transport.four_tuple());
}
}
}
}
for four_tuple in four_tuples {
server_states.remove_transport(four_tuple);
}

self.next_timeout = self.next_timeout.add(self.idle_timeout);
}
}

fn poll_timeout(&mut self, ctx: &Context<Self::Rin, Self::Rout, Self::Win, Self::Wout>, eto: &mut Instant) {
if self.next_timeout < *eto {
*eto = self.next_timeout;
}
ctx.fire_poll_timeout(eto);
}

fn poll_write(
Expand All @@ -127,7 +177,6 @@ impl Handler for GatewayHandler {
if let Some(msg) = ctx.fire_poll_write() {
self.transmits.push_back(msg);
}

self.transmits.pop_front()
}
}
Expand Down Expand Up @@ -387,6 +436,7 @@ impl GatewayHandler {
rtp_packet: rtp::packet::Packet,
) -> Result<Vec<TaggedMessageEvent>> {
debug!("handle_rtp_message {}", transport_context.peer_addr);
server_states.get_mut_transport(&(&transport_context).into())?.keep_alive();

//TODO: Selective Forwarding RTP Packets
let peers =
Expand All @@ -411,6 +461,7 @@ impl GatewayHandler {
rtcp_packets: Vec<Box<dyn rtcp::packet::Packet>>,
) -> Result<Vec<TaggedMessageEvent>> {
debug!("handle_rtcp_message {}", transport_context.peer_addr);
server_states.get_mut_transport(&(&transport_context).into())?.keep_alive();

//TODO: Selective Forwarding RTCP Packets
let peers =
Expand Down
2 changes: 1 addition & 1 deletion src/handler/stun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub struct StunHandler;

impl StunHandler {
pub fn new() -> Self {
StunHandler::default()
StunHandler
}
}

Expand Down
18 changes: 5 additions & 13 deletions src/server/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ pub struct ServerConfig {
pub(crate) sctp_endpoint_config: Arc<sctp::EndpointConfig>,
pub(crate) sctp_server_config: Arc<sctp::ServerConfig>,
pub(crate) media_config: MediaConfig,
pub(crate) endpoint_idle_timeout: Duration,
pub(crate) candidate_idle_timeout: Duration,
pub(crate) idle_timeout: Duration,
}

impl ServerConfig {
Expand All @@ -23,8 +22,7 @@ impl ServerConfig {
sctp_endpoint_config: Arc::new(sctp::EndpointConfig::default()),
sctp_server_config: Arc::new(sctp::ServerConfig::default()),
dtls_handshake_config: Arc::new(dtls::config::HandshakeConfig::default()),
endpoint_idle_timeout: Duration::from_secs(30),
candidate_idle_timeout: Duration::from_secs(30),
idle_timeout: Duration::from_secs(30),
}
}

Expand Down Expand Up @@ -58,15 +56,9 @@ impl ServerConfig {
self
}

/// build with endpoint idle timeout
pub fn with_endpoint_idle_timeout(mut self, endpoint_idle_timeout: Duration) -> Self {
self.endpoint_idle_timeout = endpoint_idle_timeout;
self
}

/// build with candidate idle timeout
pub fn with_candidate_idle_timeout(mut self, candidate_idle_timeout: Duration) -> Self {
self.candidate_idle_timeout = candidate_idle_timeout;
/// build with idle timeout
pub fn with_idle_timeout(mut self, idle_timeout: Duration) -> Self {
self.idle_timeout = idle_timeout;
self
}
}
42 changes: 40 additions & 2 deletions src/server/states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::endpoint::{
use crate::server::config::ServerConfig;
use crate::session::{config::SessionConfig, Session};
use crate::types::{EndpointId, FourTuple, SessionId, UserName};
use log::info;
use log::{debug, info};
use shared::error::{Error, Result};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
Expand Down Expand Up @@ -101,7 +101,7 @@ impl ServerStates {
local_conn_cred,
offer,
answer.clone(),
Instant::now() + self.server_config.candidate_idle_timeout,
Instant::now() + self.server_config.idle_timeout,
)));
}

Expand Down Expand Up @@ -162,6 +162,10 @@ impl ServerStates {
self.sessions.get_mut(session_id)
}

pub(crate) fn remove_session(&mut self, session_id: &SessionId) -> Option<Session> {
self.sessions.remove(session_id)
}

pub(crate) fn add_candidate(&mut self, candidate: Rc<Candidate>) -> Option<Rc<Candidate>> {
let username = candidate.username();
self.candidates.insert(username, candidate)
Expand All @@ -175,6 +179,14 @@ impl ServerStates {
self.candidates.get(username)
}

pub(crate) fn get_candidates(&self) -> &HashMap<UserName, Rc<Candidate>> {
&self.candidates
}

pub(crate) fn get_endpoints(&self) -> &HashMap<FourTuple, (SessionId, EndpointId)> {
&self.endpoints
}

pub(crate) fn add_endpoint(
&mut self,
four_tuple: FourTuple,
Expand Down Expand Up @@ -242,4 +254,30 @@ impl ServerStates {

Ok(transport)
}

pub(crate) fn remove_transport(&mut self, four_tuple: FourTuple) {
debug!("remove idle transport {:?}", four_tuple);

let Some((session_id, endpoint_id)) = self.find_endpoint(&four_tuple) else {
return;
};
let Some(session) = self.get_mut_session(&session_id) else {
return;
};
let Some(endpoint) = session.get_mut_endpoint(&endpoint_id) else {
return;
};

let transport = endpoint.remove_transport(&four_tuple);
if endpoint.get_transports().is_empty() {
session.remove_endpoint(&endpoint_id);
if session.get_endpoints().is_empty() {
self.remove_session(&session_id);
}
self.remove_endpoint(&four_tuple);
}
if let Some(transport) = transport {
self.remove_candidate(&transport.candidate().username());
}
}
}
4 changes: 4 additions & 0 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ impl Session {
self.endpoints.get_mut(endpoint_id)
}

pub(crate) fn remove_endpoint(&mut self, endpoint_id: &EndpointId) -> Option<Endpoint> {
self.endpoints.remove(endpoint_id)
}

pub(crate) fn has_endpoint(&self, endpoint_id: &EndpointId) -> bool {
self.endpoints.contains_key(endpoint_id)
}
Expand Down

0 comments on commit e01be38

Please sign in to comment.