diff --git a/examples/chat.rs b/examples/chat.rs index c8a46e6..f8890af 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -7,6 +7,7 @@ use retty::bootstrap::BootstrapUdpServer; use retty::channel::Pipeline; use retty::executor::LocalExecutorBuilder; use retty::transport::{AsyncTransport, AsyncTransportWrite, TaggedBytesMut}; +use sfu::handlers::data::DataChannelHandler; use sfu::handlers::demuxer::DemuxerHandler; use sfu::handlers::dtls::DtlsHandler; use sfu::handlers::gateway::GatewayHandler; @@ -149,6 +150,7 @@ fn main() -> anyhow::Result<()> { let stun_handler = StunHandler::new(); let dtls_handler = DtlsHandler::new(Rc::clone(&server_states_moved), dtls_handshake_config_moved.clone()); let sctp_handler = SctpHandler::new(Rc::clone(&server_states_moved), sctp_endpoint_config_moved.clone()); + let data_channel_handler = DataChannelHandler::new(); //TODO: add DTLS and RTP handlers let gateway_handler = GatewayHandler::new(Rc::clone(&server_states_moved)); @@ -157,6 +159,7 @@ fn main() -> anyhow::Result<()> { pipeline.add_back(stun_handler); pipeline.add_back(dtls_handler); pipeline.add_back(sctp_handler); + pipeline.add_back(data_channel_handler); //TODO: add DTLS and RTP handlers pipeline.add_back(gateway_handler); diff --git a/rtc b/rtc index cb8d7ee..7d58b1d 160000 --- a/rtc +++ b/rtc @@ -1 +1 @@ -Subproject commit cb8d7ee0573e1cd7d0fe73b9607007cb9d1bc2de +Subproject commit 7d58b1da5e48cc758c7b7da9cb70c28d059e7b52 diff --git a/src/handlers/data/mod.rs b/src/handlers/data/mod.rs index 8b13789..4c62732 100644 --- a/src/handlers/data/mod.rs +++ b/src/handlers/data/mod.rs @@ -1 +1,175 @@ +use crate::messages::{ + ApplicationMessage, DTLSMessageEvent, DataChannelMessage, DataChannelMessageParams, + DataChannelMessageType, MessageEvent, TaggedMessageEvent, +}; +use data::message::{message_channel_ack::*, message_channel_open::*, message_type::*, *}; +use log::debug; +use retty::channel::{Handler, InboundContext, InboundHandler, OutboundContext, OutboundHandler}; +use shared::error::{Error, Result}; +use shared::marshal::*; +#[derive(Default)] +struct DataChannelInbound; +#[derive(Default)] +struct DataChannelOutbound; +#[derive(Default)] +pub struct DataChannelHandler { + data_channel_inbound: DataChannelInbound, + data_channel_outbound: DataChannelOutbound, +} + +impl DataChannelHandler { + pub fn new() -> Self { + DataChannelHandler::default() + } +} + +impl InboundHandler for DataChannelInbound { + type Rin = TaggedMessageEvent; + type Rout = Self::Rin; + + fn read(&mut self, ctx: &InboundContext, msg: Self::Rin) { + if let MessageEvent::DTLS(DTLSMessageEvent::SCTP(message)) = msg.message { + debug!( + "recv SCTP DataChannelMessage {:?} with {:?}", + msg.transport.peer_addr, message + ); + let try_read = + || -> Result<(Option, Option)> { + if message.data_message_type == DataChannelMessageType::Control { + let mut buf = &message.payload[..]; + if MessageType::unmarshal(&mut buf)? == MessageType::DataChannelOpen { + debug!("DataChannelOpen for association_handle {} and stream_id {} and data_message_type {:?}", + message.association_handle, + message.stream_id, + message.data_message_type); + + let _ = DataChannelOpen::unmarshal(&mut buf)?; + + let payload = Message::DataChannelAck(DataChannelAck {}).marshal()?; + Ok(( + None, + Some(DataChannelMessage { + association_handle: message.association_handle, + stream_id: message.stream_id, + data_message_type: DataChannelMessageType::Control, + params: DataChannelMessageParams::Outbound { + ordered: true, + reliable: true, + max_rtx_count: 0, + max_rtx_millis: 0, + }, + payload, + }), + )) + } else { + Ok((None, None)) + } + } else if message.data_message_type == DataChannelMessageType::Binary { + Ok(( + Some(ApplicationMessage { + association_handle: message.association_handle, + stream_id: message.stream_id, + payload: message.payload, + }), + None, + )) + } else { + Err(Error::UnknownProtocol) + } + }; + + match try_read() { + Ok((inbound_message, outbound_message)) => { + if let Some(application_message) = inbound_message { + debug!("recv application message {:?}", msg.transport.peer_addr); + ctx.fire_read(TaggedMessageEvent { + now: msg.now, + transport: msg.transport, + message: MessageEvent::DTLS(DTLSMessageEvent::APPLICATION( + application_message, + )), + }) + } + if let Some(data_channel_message) = outbound_message { + debug!("send DataChannelAck message {:?}", msg.transport.peer_addr); + ctx.fire_write(TaggedMessageEvent { + now: msg.now, + transport: msg.transport, + message: MessageEvent::DTLS(DTLSMessageEvent::SCTP( + data_channel_message, + )), + }); + } + } + Err(err) => ctx.fire_read_exception(Box::new(err)), + }; + } else { + // Bypass + debug!("bypass DataChannel read {:?}", msg.transport.peer_addr); + ctx.fire_read(msg); + } + } +} + +impl OutboundHandler for DataChannelOutbound { + type Win = TaggedMessageEvent; + type Wout = Self::Win; + + fn write(&mut self, ctx: &OutboundContext, msg: Self::Win) { + if let MessageEvent::DTLS(DTLSMessageEvent::APPLICATION(message)) = msg.message { + debug!( + "send application message {:?} with {:?}", + msg.transport.peer_addr, message + ); + + ctx.fire_write(TaggedMessageEvent { + now: msg.now, + transport: msg.transport, + message: MessageEvent::DTLS(DTLSMessageEvent::SCTP(DataChannelMessage { + association_handle: message.association_handle, + stream_id: message.stream_id, + data_message_type: DataChannelMessageType::Binary, + params: DataChannelMessageParams::Outbound { + ordered: true, + reliable: true, + max_rtx_count: 0, + max_rtx_millis: 0, + }, + payload: message.payload, + })), + }); + } else { + // Bypass + debug!("bypass DataChannel write {:?}", msg.transport.peer_addr); + ctx.fire_write(msg); + } + } + + fn close(&mut self, ctx: &OutboundContext) { + ctx.fire_close(); + } +} + +impl Handler for DataChannelHandler { + type Rin = TaggedMessageEvent; + type Rout = Self::Rin; + type Win = TaggedMessageEvent; + type Wout = Self::Win; + + fn name(&self) -> &str { + "DataChannelHandler" + } + + fn split( + self, + ) -> ( + Box>, + Box>, + ) { + ( + Box::new(self.data_channel_inbound), + Box::new(self.data_channel_outbound), + ) + } +} diff --git a/src/handlers/sctp/mod.rs b/src/handlers/sctp/mod.rs index 0a3c5ae..5951d50 100644 --- a/src/handlers/sctp/mod.rs +++ b/src/handlers/sctp/mod.rs @@ -126,7 +126,9 @@ impl InboundHandler for SctpInbound { association_handle: ch.0, stream_id: id, data_message_type: to_data_message_type(chunks.ppi), - params: DataChannelMessageParams::Inbound { seq_num: 0 }, + params: DataChannelMessageParams::Inbound { + seq_num: chunks.ssn, + }, payload: BytesMut::from(&self.internal_buffer[0..n]), }); } diff --git a/src/messages.rs b/src/messages.rs index fa1f4b6..80df491 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -32,6 +32,13 @@ pub struct DataChannelMessage { pub(crate) payload: BytesMut, } +#[derive(Debug)] +pub struct ApplicationMessage { + pub(crate) association_handle: usize, + pub(crate) stream_id: u16, + pub(crate) payload: BytesMut, +} + #[derive(Debug)] pub enum STUNMessageEvent { RAW(BytesMut), @@ -42,7 +49,7 @@ pub enum STUNMessageEvent { pub enum DTLSMessageEvent { RAW(BytesMut), SCTP(DataChannelMessage), - APPLICATION(BytesMut), + APPLICATION(ApplicationMessage), } #[derive(Debug)]