From 1444b3c0635ab4a5117494990a4d2b35b54eff79 Mon Sep 17 00:00:00 2001 From: Jonas Herzig Date: Sun, 5 Apr 2020 13:36:15 +0200 Subject: [PATCH] Update rfc5764 to use futures 0.3 AsyncRead/Write With futures 0.3, the AsyncRead/Write traits are no longer marker traits, so we can no longer easily support both, non-blocking and blocking APIs. The code can however still be used in a blocking way buy just calling `now_or_never` on any futures it returns, which will block if the underlying streams are blocking. --- Cargo.toml | 11 +- examples/srtpsrv.rs | 2 +- src/rfc5764/mod.rs | 458 ++++++++++++++++++++++++----------------- src/rfc5764/openssl.rs | 162 +++++---------- src/rfc5764/tokio.rs | 95 --------- 5 files changed, 333 insertions(+), 395 deletions(-) delete mode 100644 src/rfc5764/tokio.rs diff --git a/Cargo.toml b/Cargo.toml index b293084..33cf931 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,17 +4,24 @@ version = "0.1.0" authors = ["Takeru Ohta "] edition = "2018" +[features] +default = [] +rfc5764-openssl = ["openssl", "tokio-openssl", "tokio-util/compat"] + [dependencies] trackable = "0.1" handy_async = "0.2" rust-crypto = "0.2" num = "0.1" fixedbitset = "0.1" +futures = "0.3" +async-trait = "0.1" openssl = { version = "0.10", optional = true } -tokio = { version = "0.1", optional = true } +tokio-openssl = { version = "0.4", optional = true } +tokio-util = { version = "0.3", optional = true, default-features = false } [dev-dependencies] clap = "2" fibers = "0.1" -futures = "0.1" +futures01 = { package = "futures", version = "0.1" } diff --git a/examples/srtpsrv.rs b/examples/srtpsrv.rs index 1d3e083..26461e6 100644 --- a/examples/srtpsrv.rs +++ b/examples/srtpsrv.rs @@ -1,6 +1,6 @@ extern crate clap; extern crate fibers; -extern crate futures; +extern crate futures01 as futures; #[macro_use] extern crate trackable; extern crate rtp; diff --git a/src/rfc5764/mod.rs b/src/rfc5764/mod.rs index cd05c00..7b29e94 100644 --- a/src/rfc5764/mod.rs +++ b/src/rfc5764/mod.rs @@ -1,18 +1,24 @@ // FIXME: the current SRTP implementation does not support the maximum_lifetime parameter -#[cfg(feature = "openssl")] +#[cfg(feature = "rfc5764-openssl")] mod openssl; -#[cfg(feature = "tokio")] -mod tokio; - +use async_trait::async_trait; +use futures::io::{AsyncRead, AsyncWrite}; +use futures::ready; +use futures::{Sink, Stream}; use std::collections::VecDeque; use std::io; -use std::io::{Read, Write}; +use std::io::Read; +use std::pin::Pin; use std::sync::Arc; use std::sync::Mutex; +use std::task::Context; +use std::task::Poll; -use crate::rfc3711::{AuthenticationAlgorithm, Context, EncryptionAlgorithm, Srtcp, Srtp}; +use crate::rfc3711::{ + AuthenticationAlgorithm, Context as SrtpContext, EncryptionAlgorithm, Srtcp, Srtp, +}; use crate::types::Ssrc; #[derive(Debug, Clone, PartialEq, Eq)] @@ -48,28 +54,16 @@ impl SrtpProtectionProfile { &[&SrtpProtectionProfile::AES128_CM_HMAC_SHA1_80]; } -pub enum DtlsHandshakeResult { - Failure(io::Error), - WouldBlock(DtlsMidHandshake), - Success(Dtls), -} - +#[async_trait] pub trait DtlsBuilder { type Instance: Dtls; - type MidHandshake: DtlsMidHandshake; - fn handshake(self, stream: S) -> DtlsHandshakeResult + async fn handshake(self, stream: S) -> Result where - S: Read + Write; + S: AsyncRead + AsyncWrite + Unpin + 'async_trait; } -pub trait DtlsMidHandshake: Sized { - type Instance: Dtls; - - fn handshake(self) -> DtlsHandshakeResult; -} - -pub trait Dtls: Read + Write { +pub trait Dtls: AsyncRead + AsyncWrite + Unpin { fn is_server_side(&self) -> bool; fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec; } @@ -80,7 +74,7 @@ pub struct DtlsSrtpMuxer { srtp_buf: VecDeque>, } -impl DtlsSrtpMuxer { +impl DtlsSrtpMuxer { fn new(inner: S) -> Self { DtlsSrtpMuxer { inner, @@ -102,8 +96,13 @@ impl DtlsSrtpMuxer { } } -impl DtlsSrtpMuxer { - fn read(&mut self, want_srtp: bool, dst_buf: &mut [u8]) -> io::Result { +impl DtlsSrtpMuxer { + fn read( + &mut self, + cx: &mut Context, + want_srtp: bool, + dst_buf: &mut [u8], + ) -> Poll> { { let want_buf = if want_srtp { &mut self.srtp_buf @@ -111,20 +110,20 @@ impl DtlsSrtpMuxer { &mut self.dtls_buf }; if let Some(buf) = want_buf.pop_front() { - return (&buf[..]).read(dst_buf); + return Poll::Ready((&buf[..]).read(dst_buf)); } } let mut buf = [0u8; 2048]; - let len = self.inner.read(&mut buf)?; + let len = ready!(Pin::new(&mut self.inner).poll_read(cx, &mut buf))?; if len == 0 { - return Ok(0); + return Poll::Ready(Ok(0)); } let mut buf = &buf[..len]; // Demux SRTP and DTLS as per https://tools.ietf.org/html/rfc5764#section-5.1.2 let is_srtp = buf[0] >= 128 && buf[0] <= 191; if is_srtp == want_srtp { - buf.read(dst_buf) + Poll::Ready(buf.read(dst_buf)) } else { if is_srtp { &mut self.srtp_buf @@ -138,7 +137,7 @@ impl DtlsSrtpMuxer { // by pretending that we're doing non-blocking io (even if we aren't) // to get back to where we can enter the other (in the example: the dtls) // read-path and process the packet we just read. - Err(io::Error::new(io::ErrorKind::WouldBlock, "")) + Poll::Pending // FIXME this doesn't see sound, shouldn't we store the waker!? } } } @@ -148,86 +147,61 @@ pub struct DtlsSrtpMuxerPart { srtp: bool, } -impl Read for DtlsSrtpMuxerPart +impl AsyncRead for DtlsSrtpMuxerPart where - S: Read, + S: AsyncRead + Unpin, { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.muxer.lock().unwrap().read(self.srtp, buf) + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + self.muxer.lock().unwrap().read(cx, self.srtp, buf) } } -impl Write for DtlsSrtpMuxerPart +impl AsyncWrite for DtlsSrtpMuxerPart where - S: Write, + S: AsyncWrite + Unpin, { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.muxer.lock().unwrap().inner.write(buf) + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.muxer.lock().unwrap().inner).poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.muxer.lock().unwrap().inner.flush() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.muxer.lock().unwrap().inner).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + Pin::new(&mut self.muxer.lock().unwrap().inner).poll_close(cx) } } -pub enum DtlsSrtpHandshakeResult>> { - Success(DtlsSrtp), - WouldBlock(DtlsSrtpMidHandshake), - Failure(io::Error), -} - -pub struct DtlsSrtpMidHandshake>> { - stream: DtlsSrtpMuxerPart, - dtls: D::MidHandshake, -} - -pub struct DtlsSrtp>> { +pub struct DtlsSrtp>> { stream: DtlsSrtpMuxerPart, + #[allow(dead_code)] // we'll need this once we implement re-keying dtls: D::Instance, - srtp_read_context: Context, - srtcp_read_context: Context, - srtp_write_context: Context, - srtcp_write_context: Context, -} - -impl DtlsSrtpMidHandshake -where - S: Read + Write + Sized, - D: DtlsBuilder>, -{ - pub fn handshake(mut self) -> DtlsSrtpHandshakeResult { - match self.dtls.handshake() { - DtlsHandshakeResult::Success(dtls) => { - DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(self.stream, dtls)) - } - DtlsHandshakeResult::WouldBlock(dtls) => { - self.dtls = dtls; - DtlsSrtpHandshakeResult::WouldBlock(self) - } - DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err), - } - } + srtp_read_context: SrtpContext, + srtcp_read_context: SrtpContext, + srtp_write_context: SrtpContext, + srtcp_write_context: SrtpContext, + sink_buf: Option>, } impl DtlsSrtp where - S: Read + Write, + S: AsyncRead + AsyncWrite + Unpin, D: DtlsBuilder>, { - pub fn handshake(stream: S, dtls_builder: D) -> DtlsSrtpHandshakeResult { + pub async fn handshake(stream: S, dtls_builder: D) -> Result, io::Error> { let (stream_dtls, stream_srtp) = DtlsSrtpMuxer::new(stream).into_parts(); - match dtls_builder.handshake(stream_dtls) { - DtlsHandshakeResult::Success(dtls) => { - DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(stream_srtp, dtls)) - } - DtlsHandshakeResult::WouldBlock(dtls) => { - DtlsSrtpHandshakeResult::WouldBlock(DtlsSrtpMidHandshake { - stream: stream_srtp, - dtls, - }) - } - DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err), - } + let dtls = dtls_builder.handshake(stream_dtls).await?; + Ok(DtlsSrtp::new(stream_srtp, dtls)) } fn new(stream: DtlsSrtpMuxerPart, mut dtls: D::Instance) -> Self { @@ -255,10 +229,11 @@ where DtlsSrtp { stream, dtls, - srtp_read_context: Context::new(&read_key, &read_salt), - srtcp_read_context: Context::new(&read_key, &read_salt), - srtp_write_context: Context::new(&write_key, &write_salt), - srtcp_write_context: Context::new(&write_key, &write_salt), + srtp_read_context: SrtpContext::new(&read_key, &read_salt), + srtcp_read_context: SrtpContext::new(&read_key, &read_salt), + srtp_write_context: SrtpContext::new(&write_key, &write_salt), + srtcp_write_context: SrtpContext::new(&write_key, &write_salt), + sink_buf: None, } } @@ -303,12 +278,33 @@ where } } -impl Read for DtlsSrtp +impl AsyncRead for DtlsSrtp where - S: Read + Write, + S: AsyncRead + AsyncWrite + Unpin, D: DtlsBuilder>, { - fn read(&mut self, buf: &mut [u8]) -> io::Result { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + let item = ready!(self.poll_next(cx)?); + if let Some(item) = item { + Poll::Ready((&item[..]).read(buf)) + } else { + Poll::Ready(Ok(0)) + } + } +} + +impl Stream for DtlsSrtp +where + S: AsyncRead + AsyncWrite + Unpin, + D: DtlsBuilder>, +{ + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { loop { // Check if we have an SRTP packet in the queue if self.stream.muxer.lock().unwrap().srtp_buf.is_empty() { @@ -331,13 +327,13 @@ where // Read and handle the next SRTP packet from the queue let mut raw_buf = [0u8; 2048]; - let len = self.stream.read(&mut raw_buf)?; + let len = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut raw_buf))?; if len == 0 { - return Ok(0); + return Poll::Ready(None); } let raw_buf = &raw_buf[..len]; return match self.process_incoming_srtp_packet(raw_buf) { - Some(result) => (&result[..]).read(buf), + Some(result) => Poll::Ready(Some(Ok(result))), None => { // FIXME: check rfc for whether this should be dropped silently continue; // packet failed to auth or decrypt, drop it and try the next one @@ -347,21 +343,68 @@ where } } -impl Write for DtlsSrtp +impl AsyncWrite for DtlsSrtp where - S: Read + Write, + S: AsyncRead + AsyncWrite + Unpin, D: DtlsBuilder>, { - fn write(&mut self, buf: &[u8]) -> io::Result { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { if let Some(buf) = self.process_outgoing_srtp_packet(buf) { - self.stream.write(&buf) + Pin::new(&mut self.stream).poll_write(cx, &buf) } else { - Ok(buf.len()) + Poll::Ready(Ok(buf.len())) } } - fn flush(&mut self) -> io::Result<()> { - self.stream.flush() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream).poll_close(cx) + } +} + +impl Sink<&[u8]> for DtlsSrtp +where + S: AsyncRead + AsyncWrite + Unpin, + D: DtlsBuilder>, +{ + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let _ = Sink::poll_flush(self.as_mut(), cx)?; + if self.sink_buf.is_none() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn start_send(mut self: Pin<&mut Self>, item: &[u8]) -> io::Result<()> { + self.sink_buf = self.process_outgoing_srtp_packet(item.as_ref()); + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + if let Some(buf) = self.sink_buf.take() { + match Pin::new(&mut self.stream).poll_write(cx, &buf) { + Poll::Pending => { + self.sink_buf = Some(buf); + return Poll::Pending; + } + _ => {} + } + } + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream).poll_close(cx) } } @@ -373,6 +416,28 @@ pub(crate) mod test { TEST_SRTP_SSRC, }; + use futures::{AsyncReadExt, AsyncWriteExt, FutureExt}; + + macro_rules! read_now { + ( $expr:expr, $buf:expr ) => { + $expr + .read($buf) + .now_or_never() + .expect("would block") + .expect("reading") + }; + } + + macro_rules! write_now { + ( $expr:expr, $buf:expr ) => { + $expr + .write($buf) + .now_or_never() + .expect("would block") + .expect("writing") + }; + } + struct DummyDtlsBuilder; struct DummyDtls { connected: bool, @@ -388,38 +453,31 @@ pub(crate) mod test { DummyDtlsBuilder {} } } - impl DtlsBuilder for DummyDtlsBuilder { + #[async_trait] + impl DtlsBuilder for DummyDtlsBuilder { type Instance = DummyDtls; - type MidHandshake = DummyDtls; - fn handshake( - self, - mut stream: S, - ) -> DtlsHandshakeResult { - stream.write(DUMMY_DTLS_HELLO).unwrap(); - DummyDtls { + async fn handshake(self, mut stream: S) -> Result + where + S: 'async_trait, + { + let _ = stream.write(DUMMY_DTLS_HELLO).await; + let mut dtls = DummyDtls { stream, connected: false, - } - .handshake() - } - } - impl DtlsMidHandshake for DummyDtls { - type Instance = Self; - fn handshake(mut self) -> DtlsHandshakeResult { - let result = self.read(&mut []).unwrap_err(); - if result.kind() == io::ErrorKind::WouldBlock { - if self.connected { - DtlsHandshakeResult::Success(self) + }; + loop { + let _ = futures::poll!(dtls.read(&mut [])); + if dtls.connected { + break; } else { - DtlsHandshakeResult::WouldBlock(self) + futures::pending!(); } - } else { - DtlsHandshakeResult::Failure(result) } + Ok(dtls) } } - impl Dtls for DummyDtls { + impl Dtls for DummyDtls { fn is_server_side(&self) -> bool { true } @@ -440,17 +498,21 @@ pub(crate) mod test { } } - impl Read for DummyDtls { - fn read(&mut self, _dst: &mut [u8]) -> io::Result { + impl AsyncRead for DummyDtls { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context, + _dst: &mut [u8], + ) -> Poll> { loop { let mut buf = [0u8; 2048]; - let len = self.stream.read(&mut buf)?; + let len = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?; assert_eq!(len, 2); assert_eq!(buf[1], 42); match &buf[..len] { DUMMY_DTLS_NOOP => {} DUMMY_DTLS_HELLO => { - self.stream.write(DUMMY_DTLS_CONNECTED)?; + let _ = Pin::new(&mut self.stream).poll_write(cx, DUMMY_DTLS_CONNECTED)?; } DUMMY_DTLS_CONNECTED => { self.connected = true; @@ -461,13 +523,21 @@ pub(crate) mod test { } } - impl Write for DummyDtls { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.stream.write(buf) + impl AsyncWrite for DummyDtls { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.stream.flush() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) } } @@ -498,76 +568,82 @@ pub(crate) mod test { } } - impl Read for DummyTransport { - fn read(&mut self, buf: &mut [u8]) -> io::Result { + impl AsyncRead for DummyTransport { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { match self.read_buf.lock().unwrap().pop_front() { - None => Err(io::Error::new(io::ErrorKind::WouldBlock, "")), - Some(elem) => (&mut &elem[..]).read(buf), + None => Poll::Pending, + Some(elem) => Poll::Ready(std::io::Read::read(&mut &elem[..], buf)), } } } - impl Write for DummyTransport { - fn write(&mut self, buf: &[u8]) -> io::Result { + impl AsyncWrite for DummyTransport { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context, + buf: &[u8], + ) -> Poll> { self.write_buf.lock().unwrap().push_back(buf.to_vec()); - Ok(buf.len()) + Poll::Ready(Ok(buf.len())) } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + Poll::Ready(Ok(())) } - } - macro_rules! assert_wouldblock { - ( $expr:expr ) => { - let err = $expr.unwrap_err(); - assert_eq!(err.kind(), io::ErrorKind::WouldBlock); - }; + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } } fn new_dtls_srtp() -> (DummyTransport, DtlsSrtp) { let (mut stream, dummy_stream) = DummyTransport::new(); - stream.write(DUMMY_DTLS_CONNECTED).unwrap(); - match DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()) { - DtlsSrtpHandshakeResult::Success(mut dtls_srtp) => { - assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO); - dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC); - dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC); - dtls_srtp.add_outgoing_ssrc(TEST_SRTP_SSRC); - dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC); - (stream, dtls_srtp) - } - _ => panic!("DTLS-SRTP handshake failed"), - } + write_now!(stream, DUMMY_DTLS_CONNECTED); + let mut dtls_srtp = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()) + .now_or_never() + .expect("DTLS-SRTP handshake did not complete") + .expect("DTL-SRTP handshake failed"); + assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO); + dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC); + dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC); + dtls_srtp.add_outgoing_ssrc(TEST_SRTP_SSRC); + dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC); + (stream, dtls_srtp) } #[test] fn polls_dtls_layer_for_keys() { + let mut cx = Context::from_waker(futures::task::noop_waker_ref()); let (mut stream, dummy_stream) = DummyTransport::new(); - let handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()); - let handshake = match handshake { - DtlsSrtpHandshakeResult::WouldBlock(it) => it, + let mut handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()).boxed(); + match handshake.as_mut().poll(&mut cx) { + Poll::Pending => {} _ => panic!(), }; - stream.write(TEST_SRTP_PACKET).unwrap(); // too early, should be discarded + // too early, should be discarded + write_now!(stream, TEST_SRTP_PACKET); - let handshake = match handshake.handshake() { - DtlsSrtpHandshakeResult::WouldBlock(it) => it, + match handshake.as_mut().poll(&mut cx) { + Poll::Pending => {} _ => panic!(), }; assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO); - stream.write(DUMMY_DTLS_HELLO).unwrap(); - let handshake = match handshake.handshake() { - DtlsSrtpHandshakeResult::WouldBlock(it) => it, + write_now!(stream, DUMMY_DTLS_HELLO); + match handshake.as_mut().poll(&mut cx) { + Poll::Pending => {} _ => panic!(), }; assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_CONNECTED); - stream.write(DUMMY_DTLS_CONNECTED).unwrap(); - match handshake.handshake() { - DtlsSrtpHandshakeResult::Success(_) => {} + write_now!(stream, DUMMY_DTLS_CONNECTED); + match handshake.as_mut().poll(&mut cx) { + Poll::Ready(_) => {} _ => panic!(), }; } @@ -577,10 +653,10 @@ pub(crate) mod test { let mut buf = [0u8; 2048]; let (mut stream, mut dtls_srtp) = new_dtls_srtp(); - stream.write(TEST_SRTP_PACKET).unwrap(); - stream.write(TEST_SRTCP_PACKET).unwrap(); - assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182); // srtp - assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68); // srtcp + write_now!(stream, TEST_SRTP_PACKET); + write_now!(stream, TEST_SRTCP_PACKET); + assert_eq!(read_now!(dtls_srtp, &mut buf), 182); // srtp + assert_eq!(read_now!(dtls_srtp, &mut buf), 68); // srtcp } #[test] @@ -588,16 +664,16 @@ pub(crate) mod test { let mut buf = [0u8; 2048]; let (mut stream, mut dtls_srtp) = new_dtls_srtp(); - stream.write(TEST_SRTP_PACKET).unwrap(); - stream.write(TEST_SRTP_PACKET).unwrap(); - stream.write(TEST_SRTP_PACKET).unwrap(); - assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182); - assert_wouldblock!(dtls_srtp.read(&mut buf)); + write_now!(stream, TEST_SRTP_PACKET); + write_now!(stream, TEST_SRTP_PACKET); + write_now!(stream, TEST_SRTP_PACKET); + assert_eq!(read_now!(dtls_srtp, &mut buf), 182); + assert!(dtls_srtp.read(&mut buf).now_or_never().is_none(),); - stream.write(TEST_SRTCP_PACKET).unwrap(); - stream.write(TEST_SRTCP_PACKET).unwrap(); - stream.write(TEST_SRTCP_PACKET).unwrap(); - assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68); - assert_wouldblock!(dtls_srtp.read(&mut buf)); + write_now!(stream, TEST_SRTCP_PACKET); + write_now!(stream, TEST_SRTCP_PACKET); + write_now!(stream, TEST_SRTCP_PACKET); + assert_eq!(read_now!(dtls_srtp, &mut buf), 68); + assert!(dtls_srtp.read(&mut buf).now_or_never().is_none(),); } } diff --git a/src/rfc5764/openssl.rs b/src/rfc5764/openssl.rs index ec40223..4001423 100644 --- a/src/rfc5764/openssl.rs +++ b/src/rfc5764/openssl.rs @@ -1,19 +1,22 @@ +use async_trait::async_trait; +use futures::io::{AsyncRead, AsyncWrite}; use std::io; -use std::io::{Read, Write}; +use tokio_openssl::SslStream; +use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, Tokio02AsyncReadCompatExt}; -use openssl::ssl::{ - HandshakeError, MidHandshakeSslStream, SslAcceptorBuilder, SslConnectorBuilder, SslStream, -}; +use openssl::ssl::{ConnectConfiguration, SslAcceptorBuilder}; -use crate::rfc5764::{Dtls, DtlsBuilder, DtlsHandshakeResult, DtlsMidHandshake, SrtpProtectionProfile}; +use crate::rfc5764::{Dtls, DtlsBuilder, SrtpProtectionProfile}; -impl DtlsBuilder for SslConnectorBuilder { - type Instance = SslStream; - type MidHandshake = MidHandshakeSslStream; +type CompatSslStream = Compat>>; - fn handshake(mut self, stream: S) -> DtlsHandshakeResult +#[async_trait] +impl DtlsBuilder for ConnectConfiguration { + type Instance = CompatSslStream; + + async fn handshake(mut self, stream: S) -> Result where - S: Read + Write, + S: 'async_trait, { let profiles_str: String = SrtpProtectionProfile::RECOMMENDED .iter() @@ -21,28 +24,20 @@ impl DtlsBuilder for SslConnectorBuilder { .collect::>() .join(":"); self.set_tlsext_use_srtp(&profiles_str).unwrap(); - match self.build().connect("invalid", stream) { - Ok(stream) => DtlsHandshakeResult::Success(stream), - Err(HandshakeError::WouldBlock(mid_handshake)) => { - DtlsHandshakeResult::WouldBlock(mid_handshake) - } - Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure( - io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()), - ), - Err(HandshakeError::SetupFailure(err)) => { - DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err)) - } + match tokio_openssl::connect(self, "invalid", stream.compat()).await { + Ok(stream) => Ok(stream.compat()), + Err(_) => Err(io::Error::new(io::ErrorKind::Other, "handshake error")), } } } -impl DtlsBuilder for SslAcceptorBuilder { - type Instance = SslStream; - type MidHandshake = MidHandshakeSslStream; +#[async_trait] +impl DtlsBuilder for SslAcceptorBuilder { + type Instance = CompatSslStream; - fn handshake(mut self, stream: S) -> DtlsHandshakeResult + async fn handshake(mut self, stream: S) -> Result where - S: Read + Write, + S: 'async_trait, { let profiles_str: String = SrtpProtectionProfile::RECOMMENDED .iter() @@ -50,48 +45,22 @@ impl DtlsBuilder for SslAcceptorBuilder { .collect::>() .join(":"); self.set_tlsext_use_srtp(&profiles_str).unwrap(); - match self.build().accept(stream) { - Ok(stream) => DtlsHandshakeResult::Success(stream), - Err(HandshakeError::WouldBlock(mid_handshake)) => { - DtlsHandshakeResult::WouldBlock(mid_handshake) - } - Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure( - io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()), - ), - Err(HandshakeError::SetupFailure(err)) => { - DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err)) - } + match tokio_openssl::accept(&self.build(), stream.compat()).await { + Ok(stream) => Ok(stream.compat()), + Err(_) => Err(io::Error::new(io::ErrorKind::Other, "handshake error")), } } } -impl DtlsMidHandshake for MidHandshakeSslStream { - type Instance = SslStream; - - fn handshake(self) -> DtlsHandshakeResult { - match MidHandshakeSslStream::handshake(self) { - Ok(stream) => DtlsHandshakeResult::Success(stream), - Err(HandshakeError::WouldBlock(mid_handshake)) => { - DtlsHandshakeResult::WouldBlock(mid_handshake) - } - Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure( - io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()), - ), - Err(HandshakeError::SetupFailure(err)) => { - DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err)) - } - } - } -} - -impl Dtls for SslStream { +impl Dtls for Compat>> { fn is_server_side(&self) -> bool { - self.ssl().is_server() + self.get_ref().ssl().is_server() } fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec { let mut vec = vec![0; length]; - self.ssl() + self.get_mut() + .ssl() .export_keying_material(&mut vec, exporter_label, None) .unwrap(); vec @@ -101,14 +70,16 @@ impl Dtls for SslStream { #[cfg(test)] mod test { use crate::rfc5764::test::DummyTransport; - use crate::rfc5764::{DtlsSrtp, DtlsSrtpHandshakeResult}; + use crate::rfc5764::DtlsSrtp; + use futures::FutureExt; use openssl::asn1::Asn1Time; use openssl::hash::MessageDigest; use openssl::pkey::PKey; use openssl::rsa::Rsa; use openssl::ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode}; use openssl::x509::X509; + use std::task::{Context, Poll}; #[test] fn connect_and_establish_matching_key_material() { @@ -134,54 +105,33 @@ mod test { acceptor.set_private_key(&key).unwrap(); acceptor.set_verify(SslVerifyMode::NONE); connector.set_verify(SslVerifyMode::NONE); - let mut handshake_server = DtlsSrtp::handshake(server_sock, acceptor); - let mut handshake_client = DtlsSrtp::handshake(client_sock, connector); + let handshake_server = DtlsSrtp::handshake(server_sock, acceptor); + let handshake_client = + DtlsSrtp::handshake(client_sock, connector.build().configure().unwrap()); + let mut future = futures::future::join(handshake_server, handshake_client).boxed(); + let mut cx = Context::from_waker(futures::task::noop_waker_ref()); loop { - match (handshake_client, handshake_server) { - (DtlsSrtpHandshakeResult::Failure(err), _) => { - panic!("Client error: {}", err); - } - (_, DtlsSrtpHandshakeResult::Failure(err)) => { - panic!("Server error: {}", err); - } - ( - DtlsSrtpHandshakeResult::Success(client), - DtlsSrtpHandshakeResult::Success(server), - ) => { - assert_eq!( - client.srtp_read_context.master_key, - server.srtp_write_context.master_key - ); - assert_eq!( - client.srtp_read_context.master_salt, - server.srtp_write_context.master_salt - ); - assert_eq!( - client.srtp_write_context.master_key, - server.srtp_read_context.master_key - ); - assert_eq!( - client.srtp_write_context.master_salt, - server.srtp_read_context.master_salt - ); - return; - } - ( - DtlsSrtpHandshakeResult::WouldBlock(client), - DtlsSrtpHandshakeResult::WouldBlock(server), - ) => { - handshake_client = client.handshake(); - handshake_server = server.handshake(); - } - (client, DtlsSrtpHandshakeResult::WouldBlock(server)) => { - handshake_client = client; - handshake_server = server.handshake(); - } - (DtlsSrtpHandshakeResult::WouldBlock(client), server) => { - handshake_client = client.handshake(); - handshake_server = server; - } + if let Poll::Ready((server, client)) = future.as_mut().poll(&mut cx) { + let server = server.unwrap(); + let client = client.unwrap(); + assert_eq!( + client.srtp_read_context.master_key, + server.srtp_write_context.master_key + ); + assert_eq!( + client.srtp_read_context.master_salt, + server.srtp_write_context.master_salt + ); + assert_eq!( + client.srtp_write_context.master_key, + server.srtp_read_context.master_key + ); + assert_eq!( + client.srtp_write_context.master_salt, + server.srtp_read_context.master_salt + ); + return; } } } diff --git a/src/rfc5764/tokio.rs b/src/rfc5764/tokio.rs deleted file mode 100644 index 5088cb1..0000000 --- a/src/rfc5764/tokio.rs +++ /dev/null @@ -1,95 +0,0 @@ -use std::io; -use tokio::prelude::{Async, AsyncRead, AsyncSink, AsyncWrite, Future, Sink, Stream}; - -use crate::rfc5764::{DtlsBuilder, DtlsSrtp, DtlsSrtpHandshakeResult, DtlsSrtpMuxerPart}; - -impl AsyncRead for DtlsSrtp -where - S: AsyncRead + AsyncWrite, - D: DtlsBuilder>, -{ -} - -impl AsyncWrite for DtlsSrtp -where - S: AsyncRead + AsyncWrite, - D: DtlsBuilder>, -{ - fn shutdown(&mut self) -> io::Result> { - Ok(().into()) // FIXME - } -} - -impl Future for DtlsSrtpHandshakeResult -where - S: AsyncRead + AsyncWrite, - D: DtlsBuilder>, -{ - type Item = DtlsSrtp; - type Error = io::Error; - - fn poll(&mut self) -> io::Result> { - let mut owned = DtlsSrtpHandshakeResult::Failure(io::Error::new( - io::ErrorKind::Other, - "poll called after completion", - )); - std::mem::swap(&mut owned, self); - match owned { - DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)), - DtlsSrtpHandshakeResult::WouldBlock(mid_handshake) => match mid_handshake.handshake() { - DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)), - mut new @ DtlsSrtpHandshakeResult::WouldBlock(_) => { - std::mem::swap(&mut new, self); - Ok(Async::NotReady) - } - DtlsSrtpHandshakeResult::Failure(err) => Err(err), - }, - DtlsSrtpHandshakeResult::Failure(err) => Err(err), - } - } -} - -impl Stream for DtlsSrtp -where - S: AsyncRead + AsyncWrite, - D: DtlsBuilder>, - DtlsSrtp: AsyncRead + AsyncWrite, -{ - type Item = Vec; - type Error = io::Error; - - fn poll(&mut self) -> io::Result>> { - let mut buf = [0; 2048]; - Ok(match self.poll_read(&mut buf)? { - Async::Ready(len) => { - if len == 0 { - Async::Ready(None) - } else { - Async::Ready(Some(buf[..len].to_vec())) - } - } - Async::NotReady => Async::NotReady, - }) - } -} - -impl Sink for DtlsSrtp -where - S: AsyncRead + AsyncWrite, - D: DtlsBuilder>, - DtlsSrtp: AsyncRead + AsyncWrite, -{ - type SinkItem = Vec; - type SinkError = io::Error; - - fn start_send(&mut self, buf: Self::SinkItem) -> io::Result> { - Ok(match self.poll_write(&buf[..])? { - Async::Ready(_) => AsyncSink::Ready, - Async::NotReady => AsyncSink::NotReady(buf), - }) - } - - fn poll_complete(&mut self) -> io::Result> { - self.poll_flush() - } -}