Update rfc5764 to use futures 0.3 AsyncRead/Write

With futures 0.3, the AsyncRead/Write traits are no longer marker traits, so we
can no longer easily support both, non-blocking and blocking APIs.
The code can however still be used in a blocking way buy just calling
`now_or_never` on any futures it returns, which will block if the underlying
streams are blocking.
dtls-srtp
Jonas Herzig 2020-04-05 13:36:15 +02:00
parent 3510417ce4
commit 1444b3c063
5 changed files with 333 additions and 395 deletions

View File

@ -4,17 +4,24 @@ version = "0.1.0"
authors = ["Takeru Ohta <phjgt308@gmail.com>"] authors = ["Takeru Ohta <phjgt308@gmail.com>"]
edition = "2018" edition = "2018"
[features]
default = []
rfc5764-openssl = ["openssl", "tokio-openssl", "tokio-util/compat"]
[dependencies] [dependencies]
trackable = "0.1" trackable = "0.1"
handy_async = "0.2" handy_async = "0.2"
rust-crypto = "0.2" rust-crypto = "0.2"
num = "0.1" num = "0.1"
fixedbitset = "0.1" fixedbitset = "0.1"
futures = "0.3"
async-trait = "0.1"
openssl = { version = "0.10", optional = true } openssl = { version = "0.10", optional = true }
tokio = { version = "0.1", optional = true } tokio-openssl = { version = "0.4", optional = true }
tokio-util = { version = "0.3", optional = true, default-features = false }
[dev-dependencies] [dev-dependencies]
clap = "2" clap = "2"
fibers = "0.1" fibers = "0.1"
futures = "0.1" futures01 = { package = "futures", version = "0.1" }

View File

@ -1,6 +1,6 @@
extern crate clap; extern crate clap;
extern crate fibers; extern crate fibers;
extern crate futures; extern crate futures01 as futures;
#[macro_use] #[macro_use]
extern crate trackable; extern crate trackable;
extern crate rtp; extern crate rtp;

View File

@ -1,18 +1,24 @@
// FIXME: the current SRTP implementation does not support the maximum_lifetime parameter // FIXME: the current SRTP implementation does not support the maximum_lifetime parameter
#[cfg(feature = "openssl")] #[cfg(feature = "rfc5764-openssl")]
mod openssl; mod openssl;
#[cfg(feature = "tokio")] use async_trait::async_trait;
mod tokio; use futures::io::{AsyncRead, AsyncWrite};
use futures::ready;
use futures::{Sink, Stream};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::io; use std::io;
use std::io::{Read, Write}; use std::io::Read;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use std::task::Context;
use std::task::Poll;
use crate::rfc3711::{AuthenticationAlgorithm, Context, EncryptionAlgorithm, Srtcp, Srtp}; use crate::rfc3711::{
AuthenticationAlgorithm, Context as SrtpContext, EncryptionAlgorithm, Srtcp, Srtp,
};
use crate::types::Ssrc; use crate::types::Ssrc;
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -48,28 +54,16 @@ impl SrtpProtectionProfile {
&[&SrtpProtectionProfile::AES128_CM_HMAC_SHA1_80]; &[&SrtpProtectionProfile::AES128_CM_HMAC_SHA1_80];
} }
pub enum DtlsHandshakeResult<Dtls, DtlsMidHandshake> { #[async_trait]
Failure(io::Error),
WouldBlock(DtlsMidHandshake),
Success(Dtls),
}
pub trait DtlsBuilder<S> { pub trait DtlsBuilder<S> {
type Instance: Dtls<S>; type Instance: Dtls<S>;
type MidHandshake: DtlsMidHandshake<S, Instance = Self::Instance>;
fn handshake(self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake> async fn handshake(self, stream: S) -> Result<Self::Instance, io::Error>
where where
S: Read + Write; S: AsyncRead + AsyncWrite + Unpin + 'async_trait;
} }
pub trait DtlsMidHandshake<S>: Sized { pub trait Dtls<S>: AsyncRead + AsyncWrite + Unpin {
type Instance: Dtls<S>;
fn handshake(self) -> DtlsHandshakeResult<Self::Instance, Self>;
}
pub trait Dtls<S>: Read + Write {
fn is_server_side(&self) -> bool; fn is_server_side(&self) -> bool;
fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8>; fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8>;
} }
@ -80,7 +74,7 @@ pub struct DtlsSrtpMuxer<S> {
srtp_buf: VecDeque<Vec<u8>>, srtp_buf: VecDeque<Vec<u8>>,
} }
impl<S: Read + Write> DtlsSrtpMuxer<S> { impl<S: AsyncRead + AsyncWrite> DtlsSrtpMuxer<S> {
fn new(inner: S) -> Self { fn new(inner: S) -> Self {
DtlsSrtpMuxer { DtlsSrtpMuxer {
inner, inner,
@ -102,8 +96,13 @@ impl<S> DtlsSrtpMuxer<S> {
} }
} }
impl<S: Read> DtlsSrtpMuxer<S> { impl<S: AsyncRead + Unpin> DtlsSrtpMuxer<S> {
fn read(&mut self, want_srtp: bool, dst_buf: &mut [u8]) -> io::Result<usize> { fn read(
&mut self,
cx: &mut Context,
want_srtp: bool,
dst_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
{ {
let want_buf = if want_srtp { let want_buf = if want_srtp {
&mut self.srtp_buf &mut self.srtp_buf
@ -111,20 +110,20 @@ impl<S: Read> DtlsSrtpMuxer<S> {
&mut self.dtls_buf &mut self.dtls_buf
}; };
if let Some(buf) = want_buf.pop_front() { if let Some(buf) = want_buf.pop_front() {
return (&buf[..]).read(dst_buf); return Poll::Ready((&buf[..]).read(dst_buf));
} }
} }
let mut buf = [0u8; 2048]; let mut buf = [0u8; 2048];
let len = self.inner.read(&mut buf)?; let len = ready!(Pin::new(&mut self.inner).poll_read(cx, &mut buf))?;
if len == 0 { if len == 0 {
return Ok(0); return Poll::Ready(Ok(0));
} }
let mut buf = &buf[..len]; let mut buf = &buf[..len];
// Demux SRTP and DTLS as per https://tools.ietf.org/html/rfc5764#section-5.1.2 // Demux SRTP and DTLS as per https://tools.ietf.org/html/rfc5764#section-5.1.2
let is_srtp = buf[0] >= 128 && buf[0] <= 191; let is_srtp = buf[0] >= 128 && buf[0] <= 191;
if is_srtp == want_srtp { if is_srtp == want_srtp {
buf.read(dst_buf) Poll::Ready(buf.read(dst_buf))
} else { } else {
if is_srtp { if is_srtp {
&mut self.srtp_buf &mut self.srtp_buf
@ -138,7 +137,7 @@ impl<S: Read> DtlsSrtpMuxer<S> {
// by pretending that we're doing non-blocking io (even if we aren't) // by pretending that we're doing non-blocking io (even if we aren't)
// to get back to where we can enter the other (in the example: the dtls) // to get back to where we can enter the other (in the example: the dtls)
// read-path and process the packet we just read. // read-path and process the packet we just read.
Err(io::Error::new(io::ErrorKind::WouldBlock, "")) Poll::Pending // FIXME this doesn't see sound, shouldn't we store the waker!?
} }
} }
} }
@ -148,86 +147,61 @@ pub struct DtlsSrtpMuxerPart<S> {
srtp: bool, srtp: bool,
} }
impl<S> Read for DtlsSrtpMuxerPart<S> impl<S> AsyncRead for DtlsSrtpMuxerPart<S>
where where
S: Read, S: AsyncRead + Unpin,
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn poll_read(
self.muxer.lock().unwrap().read(self.srtp, buf) self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.muxer.lock().unwrap().read(cx, self.srtp, buf)
} }
} }
impl<S> Write for DtlsSrtpMuxerPart<S> impl<S> AsyncWrite for DtlsSrtpMuxerPart<S>
where where
S: Write, S: AsyncWrite + Unpin,
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn poll_write(
self.muxer.lock().unwrap().inner.write(buf) mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.muxer.lock().unwrap().inner).poll_write(cx, buf)
} }
fn flush(&mut self) -> io::Result<()> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.muxer.lock().unwrap().inner.flush() Pin::new(&mut self.muxer.lock().unwrap().inner).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
ready!(self.as_mut().poll_flush(cx))?;
Pin::new(&mut self.muxer.lock().unwrap().inner).poll_close(cx)
} }
} }
pub enum DtlsSrtpHandshakeResult<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> { pub struct DtlsSrtp<S: AsyncRead + AsyncWrite, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
Success(DtlsSrtp<S, D>),
WouldBlock(DtlsSrtpMidHandshake<S, D>),
Failure(io::Error),
}
pub struct DtlsSrtpMidHandshake<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
stream: DtlsSrtpMuxerPart<S>,
dtls: D::MidHandshake,
}
pub struct DtlsSrtp<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
stream: DtlsSrtpMuxerPart<S>, stream: DtlsSrtpMuxerPart<S>,
#[allow(dead_code)] // we'll need this once we implement re-keying
dtls: D::Instance, dtls: D::Instance,
srtp_read_context: Context<Srtp>, srtp_read_context: SrtpContext<Srtp>,
srtcp_read_context: Context<Srtcp>, srtcp_read_context: SrtpContext<Srtcp>,
srtp_write_context: Context<Srtp>, srtp_write_context: SrtpContext<Srtp>,
srtcp_write_context: Context<Srtcp>, srtcp_write_context: SrtpContext<Srtcp>,
} sink_buf: Option<Vec<u8>>,
impl<S, D> DtlsSrtpMidHandshake<S, D>
where
S: Read + Write + Sized,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
pub fn handshake(mut self) -> DtlsSrtpHandshakeResult<S, D> {
match self.dtls.handshake() {
DtlsHandshakeResult::Success(dtls) => {
DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(self.stream, dtls))
}
DtlsHandshakeResult::WouldBlock(dtls) => {
self.dtls = dtls;
DtlsSrtpHandshakeResult::WouldBlock(self)
}
DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err),
}
}
} }
impl<S, D> DtlsSrtp<S, D> impl<S, D> DtlsSrtp<S, D>
where where
S: Read + Write, S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{ {
pub fn handshake(stream: S, dtls_builder: D) -> DtlsSrtpHandshakeResult<S, D> { pub async fn handshake(stream: S, dtls_builder: D) -> Result<DtlsSrtp<S, D>, io::Error> {
let (stream_dtls, stream_srtp) = DtlsSrtpMuxer::new(stream).into_parts(); let (stream_dtls, stream_srtp) = DtlsSrtpMuxer::new(stream).into_parts();
match dtls_builder.handshake(stream_dtls) { let dtls = dtls_builder.handshake(stream_dtls).await?;
DtlsHandshakeResult::Success(dtls) => { Ok(DtlsSrtp::new(stream_srtp, dtls))
DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(stream_srtp, dtls))
}
DtlsHandshakeResult::WouldBlock(dtls) => {
DtlsSrtpHandshakeResult::WouldBlock(DtlsSrtpMidHandshake {
stream: stream_srtp,
dtls,
})
}
DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err),
}
} }
fn new(stream: DtlsSrtpMuxerPart<S>, mut dtls: D::Instance) -> Self { fn new(stream: DtlsSrtpMuxerPart<S>, mut dtls: D::Instance) -> Self {
@ -255,10 +229,11 @@ where
DtlsSrtp { DtlsSrtp {
stream, stream,
dtls, dtls,
srtp_read_context: Context::new(&read_key, &read_salt), srtp_read_context: SrtpContext::new(&read_key, &read_salt),
srtcp_read_context: Context::new(&read_key, &read_salt), srtcp_read_context: SrtpContext::new(&read_key, &read_salt),
srtp_write_context: Context::new(&write_key, &write_salt), srtp_write_context: SrtpContext::new(&write_key, &write_salt),
srtcp_write_context: Context::new(&write_key, &write_salt), srtcp_write_context: SrtpContext::new(&write_key, &write_salt),
sink_buf: None,
} }
} }
@ -303,12 +278,33 @@ where
} }
} }
impl<S, D> Read for DtlsSrtp<S, D> impl<S, D> AsyncRead for DtlsSrtp<S, D>
where where
S: Read + Write, S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let item = ready!(self.poll_next(cx)?);
if let Some(item) = item {
Poll::Ready((&item[..]).read(buf))
} else {
Poll::Ready(Ok(0))
}
}
}
impl<S, D> Stream for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
type Item = io::Result<Vec<u8>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop { loop {
// Check if we have an SRTP packet in the queue // Check if we have an SRTP packet in the queue
if self.stream.muxer.lock().unwrap().srtp_buf.is_empty() { if self.stream.muxer.lock().unwrap().srtp_buf.is_empty() {
@ -331,13 +327,13 @@ where
// Read and handle the next SRTP packet from the queue // Read and handle the next SRTP packet from the queue
let mut raw_buf = [0u8; 2048]; let mut raw_buf = [0u8; 2048];
let len = self.stream.read(&mut raw_buf)?; let len = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut raw_buf))?;
if len == 0 { if len == 0 {
return Ok(0); return Poll::Ready(None);
} }
let raw_buf = &raw_buf[..len]; let raw_buf = &raw_buf[..len];
return match self.process_incoming_srtp_packet(raw_buf) { return match self.process_incoming_srtp_packet(raw_buf) {
Some(result) => (&result[..]).read(buf), Some(result) => Poll::Ready(Some(Ok(result))),
None => { None => {
// FIXME: check rfc for whether this should be dropped silently // FIXME: check rfc for whether this should be dropped silently
continue; // packet failed to auth or decrypt, drop it and try the next one continue; // packet failed to auth or decrypt, drop it and try the next one
@ -347,21 +343,68 @@ where
} }
} }
impl<S, D> Write for DtlsSrtp<S, D> impl<S, D> AsyncWrite for DtlsSrtp<S, D>
where where
S: Read + Write, S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if let Some(buf) = self.process_outgoing_srtp_packet(buf) { if let Some(buf) = self.process_outgoing_srtp_packet(buf) {
self.stream.write(&buf) Pin::new(&mut self.stream).poll_write(cx, &buf)
} else { } else {
Ok(buf.len()) Poll::Ready(Ok(buf.len()))
} }
} }
fn flush(&mut self) -> io::Result<()> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.stream.flush() Pin::new(&mut self.stream).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_close(cx)
}
}
impl<S, D> Sink<&[u8]> for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
type Error = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let _ = Sink::poll_flush(self.as_mut(), cx)?;
if self.sink_buf.is_none() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
fn start_send(mut self: Pin<&mut Self>, item: &[u8]) -> io::Result<()> {
self.sink_buf = self.process_outgoing_srtp_packet(item.as_ref());
Ok(())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
if let Some(buf) = self.sink_buf.take() {
match Pin::new(&mut self.stream).poll_write(cx, &buf) {
Poll::Pending => {
self.sink_buf = Some(buf);
return Poll::Pending;
}
_ => {}
}
}
Pin::new(&mut self.stream).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_close(cx)
} }
} }
@ -373,6 +416,28 @@ pub(crate) mod test {
TEST_SRTP_SSRC, TEST_SRTP_SSRC,
}; };
use futures::{AsyncReadExt, AsyncWriteExt, FutureExt};
macro_rules! read_now {
( $expr:expr, $buf:expr ) => {
$expr
.read($buf)
.now_or_never()
.expect("would block")
.expect("reading")
};
}
macro_rules! write_now {
( $expr:expr, $buf:expr ) => {
$expr
.write($buf)
.now_or_never()
.expect("would block")
.expect("writing")
};
}
struct DummyDtlsBuilder; struct DummyDtlsBuilder;
struct DummyDtls<S> { struct DummyDtls<S> {
connected: bool, connected: bool,
@ -388,38 +453,31 @@ pub(crate) mod test {
DummyDtlsBuilder {} DummyDtlsBuilder {}
} }
} }
impl<S: Read + Write> DtlsBuilder<S> for DummyDtlsBuilder { #[async_trait]
impl<S: AsyncRead + AsyncWrite + Unpin + Send> DtlsBuilder<S> for DummyDtlsBuilder {
type Instance = DummyDtls<S>; type Instance = DummyDtls<S>;
type MidHandshake = DummyDtls<S>;
fn handshake( async fn handshake(self, mut stream: S) -> Result<Self::Instance, io::Error>
self, where
mut stream: S, S: 'async_trait,
) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake> { {
stream.write(DUMMY_DTLS_HELLO).unwrap(); let _ = stream.write(DUMMY_DTLS_HELLO).await;
DummyDtls { let mut dtls = DummyDtls {
stream, stream,
connected: false, connected: false,
} };
.handshake() loop {
} let _ = futures::poll!(dtls.read(&mut []));
} if dtls.connected {
impl<S: Read + Write> DtlsMidHandshake<S> for DummyDtls<S> { break;
type Instance = Self;
fn handshake(mut self) -> DtlsHandshakeResult<Self::Instance, Self> {
let result = self.read(&mut []).unwrap_err();
if result.kind() == io::ErrorKind::WouldBlock {
if self.connected {
DtlsHandshakeResult::Success(self)
} else { } else {
DtlsHandshakeResult::WouldBlock(self) futures::pending!();
}
} else {
DtlsHandshakeResult::Failure(result)
} }
} }
Ok(dtls)
} }
impl<S: Read + Write> Dtls<S> for DummyDtls<S> { }
impl<S: AsyncRead + AsyncWrite + Unpin> Dtls<S> for DummyDtls<S> {
fn is_server_side(&self) -> bool { fn is_server_side(&self) -> bool {
true true
} }
@ -440,17 +498,21 @@ pub(crate) mod test {
} }
} }
impl<S: Read + Write> Read for DummyDtls<S> { impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for DummyDtls<S> {
fn read(&mut self, _dst: &mut [u8]) -> io::Result<usize> { fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
_dst: &mut [u8],
) -> Poll<io::Result<usize>> {
loop { loop {
let mut buf = [0u8; 2048]; let mut buf = [0u8; 2048];
let len = self.stream.read(&mut buf)?; let len = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?;
assert_eq!(len, 2); assert_eq!(len, 2);
assert_eq!(buf[1], 42); assert_eq!(buf[1], 42);
match &buf[..len] { match &buf[..len] {
DUMMY_DTLS_NOOP => {} DUMMY_DTLS_NOOP => {}
DUMMY_DTLS_HELLO => { DUMMY_DTLS_HELLO => {
self.stream.write(DUMMY_DTLS_CONNECTED)?; let _ = Pin::new(&mut self.stream).poll_write(cx, DUMMY_DTLS_CONNECTED)?;
} }
DUMMY_DTLS_CONNECTED => { DUMMY_DTLS_CONNECTED => {
self.connected = true; self.connected = true;
@ -461,13 +523,21 @@ pub(crate) mod test {
} }
} }
impl<S: Write> Write for DummyDtls<S> { impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for DummyDtls<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn poll_write(
self.stream.write(buf) mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
} }
fn flush(&mut self) -> io::Result<()> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.stream.flush() Pin::new(&mut self.stream).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_flush(cx)
} }
} }
@ -498,38 +568,45 @@ pub(crate) mod test {
} }
} }
impl Read for DummyTransport { impl AsyncRead for DummyTransport {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.read_buf.lock().unwrap().pop_front() { match self.read_buf.lock().unwrap().pop_front() {
None => Err(io::Error::new(io::ErrorKind::WouldBlock, "")), None => Poll::Pending,
Some(elem) => (&mut &elem[..]).read(buf), Some(elem) => Poll::Ready(std::io::Read::read(&mut &elem[..], buf)),
} }
} }
} }
impl Write for DummyTransport { impl AsyncWrite for DummyTransport {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.write_buf.lock().unwrap().push_back(buf.to_vec()); self.write_buf.lock().unwrap().push_back(buf.to_vec());
Ok(buf.len()) Poll::Ready(Ok(buf.len()))
} }
fn flush(&mut self) -> io::Result<()> { fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Ok(()) Poll::Ready(Ok(()))
}
} }
macro_rules! assert_wouldblock { fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
( $expr:expr ) => { Poll::Ready(Ok(()))
let err = $expr.unwrap_err(); }
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
};
} }
fn new_dtls_srtp() -> (DummyTransport, DtlsSrtp<DummyTransport, DummyDtlsBuilder>) { fn new_dtls_srtp() -> (DummyTransport, DtlsSrtp<DummyTransport, DummyDtlsBuilder>) {
let (mut stream, dummy_stream) = DummyTransport::new(); let (mut stream, dummy_stream) = DummyTransport::new();
stream.write(DUMMY_DTLS_CONNECTED).unwrap(); write_now!(stream, DUMMY_DTLS_CONNECTED);
match DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()) { let mut dtls_srtp = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new())
DtlsSrtpHandshakeResult::Success(mut dtls_srtp) => { .now_or_never()
.expect("DTLS-SRTP handshake did not complete")
.expect("DTL-SRTP handshake failed");
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO); assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC); dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC);
dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC); dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC);
@ -537,37 +614,36 @@ pub(crate) mod test {
dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC); dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC);
(stream, dtls_srtp) (stream, dtls_srtp)
} }
_ => panic!("DTLS-SRTP handshake failed"),
}
}
#[test] #[test]
fn polls_dtls_layer_for_keys() { fn polls_dtls_layer_for_keys() {
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
let (mut stream, dummy_stream) = DummyTransport::new(); let (mut stream, dummy_stream) = DummyTransport::new();
let handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()); let mut handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()).boxed();
let handshake = match handshake { match handshake.as_mut().poll(&mut cx) {
DtlsSrtpHandshakeResult::WouldBlock(it) => it, Poll::Pending => {}
_ => panic!(), _ => panic!(),
}; };
stream.write(TEST_SRTP_PACKET).unwrap(); // too early, should be discarded // too early, should be discarded
write_now!(stream, TEST_SRTP_PACKET);
let handshake = match handshake.handshake() { match handshake.as_mut().poll(&mut cx) {
DtlsSrtpHandshakeResult::WouldBlock(it) => it, Poll::Pending => {}
_ => panic!(), _ => panic!(),
}; };
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO); assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
stream.write(DUMMY_DTLS_HELLO).unwrap(); write_now!(stream, DUMMY_DTLS_HELLO);
let handshake = match handshake.handshake() { match handshake.as_mut().poll(&mut cx) {
DtlsSrtpHandshakeResult::WouldBlock(it) => it, Poll::Pending => {}
_ => panic!(), _ => panic!(),
}; };
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_CONNECTED); assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_CONNECTED);
stream.write(DUMMY_DTLS_CONNECTED).unwrap(); write_now!(stream, DUMMY_DTLS_CONNECTED);
match handshake.handshake() { match handshake.as_mut().poll(&mut cx) {
DtlsSrtpHandshakeResult::Success(_) => {} Poll::Ready(_) => {}
_ => panic!(), _ => panic!(),
}; };
} }
@ -577,10 +653,10 @@ pub(crate) mod test {
let mut buf = [0u8; 2048]; let mut buf = [0u8; 2048];
let (mut stream, mut dtls_srtp) = new_dtls_srtp(); let (mut stream, mut dtls_srtp) = new_dtls_srtp();
stream.write(TEST_SRTP_PACKET).unwrap(); write_now!(stream, TEST_SRTP_PACKET);
stream.write(TEST_SRTCP_PACKET).unwrap(); write_now!(stream, TEST_SRTCP_PACKET);
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182); // srtp assert_eq!(read_now!(dtls_srtp, &mut buf), 182); // srtp
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68); // srtcp assert_eq!(read_now!(dtls_srtp, &mut buf), 68); // srtcp
} }
#[test] #[test]
@ -588,16 +664,16 @@ pub(crate) mod test {
let mut buf = [0u8; 2048]; let mut buf = [0u8; 2048];
let (mut stream, mut dtls_srtp) = new_dtls_srtp(); let (mut stream, mut dtls_srtp) = new_dtls_srtp();
stream.write(TEST_SRTP_PACKET).unwrap(); write_now!(stream, TEST_SRTP_PACKET);
stream.write(TEST_SRTP_PACKET).unwrap(); write_now!(stream, TEST_SRTP_PACKET);
stream.write(TEST_SRTP_PACKET).unwrap(); write_now!(stream, TEST_SRTP_PACKET);
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182); assert_eq!(read_now!(dtls_srtp, &mut buf), 182);
assert_wouldblock!(dtls_srtp.read(&mut buf)); assert!(dtls_srtp.read(&mut buf).now_or_never().is_none(),);
stream.write(TEST_SRTCP_PACKET).unwrap(); write_now!(stream, TEST_SRTCP_PACKET);
stream.write(TEST_SRTCP_PACKET).unwrap(); write_now!(stream, TEST_SRTCP_PACKET);
stream.write(TEST_SRTCP_PACKET).unwrap(); write_now!(stream, TEST_SRTCP_PACKET);
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68); assert_eq!(read_now!(dtls_srtp, &mut buf), 68);
assert_wouldblock!(dtls_srtp.read(&mut buf)); assert!(dtls_srtp.read(&mut buf).now_or_never().is_none(),);
} }
} }

View File

@ -1,19 +1,22 @@
use async_trait::async_trait;
use futures::io::{AsyncRead, AsyncWrite};
use std::io; use std::io;
use std::io::{Read, Write}; use tokio_openssl::SslStream;
use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, Tokio02AsyncReadCompatExt};
use openssl::ssl::{ use openssl::ssl::{ConnectConfiguration, SslAcceptorBuilder};
HandshakeError, MidHandshakeSslStream, SslAcceptorBuilder, SslConnectorBuilder, SslStream,
};
use crate::rfc5764::{Dtls, DtlsBuilder, DtlsHandshakeResult, DtlsMidHandshake, SrtpProtectionProfile}; use crate::rfc5764::{Dtls, DtlsBuilder, SrtpProtectionProfile};
impl<S: Read + Write + Sync> DtlsBuilder<S> for SslConnectorBuilder { type CompatSslStream<S> = Compat<SslStream<Compat<S>>>;
type Instance = SslStream<S>;
type MidHandshake = MidHandshakeSslStream<S>;
fn handshake(mut self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake> #[async_trait]
impl<S: AsyncRead + AsyncWrite + Send + Unpin> DtlsBuilder<S> for ConnectConfiguration {
type Instance = CompatSslStream<S>;
async fn handshake(mut self, stream: S) -> Result<Self::Instance, io::Error>
where where
S: Read + Write, S: 'async_trait,
{ {
let profiles_str: String = SrtpProtectionProfile::RECOMMENDED let profiles_str: String = SrtpProtectionProfile::RECOMMENDED
.iter() .iter()
@ -21,28 +24,20 @@ impl<S: Read + Write + Sync> DtlsBuilder<S> for SslConnectorBuilder {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(":"); .join(":");
self.set_tlsext_use_srtp(&profiles_str).unwrap(); self.set_tlsext_use_srtp(&profiles_str).unwrap();
match self.build().connect("invalid", stream) { match tokio_openssl::connect(self, "invalid", stream.compat()).await {
Ok(stream) => DtlsHandshakeResult::Success(stream), Ok(stream) => Ok(stream.compat()),
Err(HandshakeError::WouldBlock(mid_handshake)) => { Err(_) => Err(io::Error::new(io::ErrorKind::Other, "handshake error")),
DtlsHandshakeResult::WouldBlock(mid_handshake)
}
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
),
Err(HandshakeError::SetupFailure(err)) => {
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
}
} }
} }
} }
impl<S: Read + Write + Sync> DtlsBuilder<S> for SslAcceptorBuilder { #[async_trait]
type Instance = SslStream<S>; impl<S: AsyncRead + AsyncWrite + Send + Unpin> DtlsBuilder<S> for SslAcceptorBuilder {
type MidHandshake = MidHandshakeSslStream<S>; type Instance = CompatSslStream<S>;
fn handshake(mut self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake> async fn handshake(mut self, stream: S) -> Result<Self::Instance, io::Error>
where where
S: Read + Write, S: 'async_trait,
{ {
let profiles_str: String = SrtpProtectionProfile::RECOMMENDED let profiles_str: String = SrtpProtectionProfile::RECOMMENDED
.iter() .iter()
@ -50,48 +45,22 @@ impl<S: Read + Write + Sync> DtlsBuilder<S> for SslAcceptorBuilder {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(":"); .join(":");
self.set_tlsext_use_srtp(&profiles_str).unwrap(); self.set_tlsext_use_srtp(&profiles_str).unwrap();
match self.build().accept(stream) { match tokio_openssl::accept(&self.build(), stream.compat()).await {
Ok(stream) => DtlsHandshakeResult::Success(stream), Ok(stream) => Ok(stream.compat()),
Err(HandshakeError::WouldBlock(mid_handshake)) => { Err(_) => Err(io::Error::new(io::ErrorKind::Other, "handshake error")),
DtlsHandshakeResult::WouldBlock(mid_handshake)
}
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
),
Err(HandshakeError::SetupFailure(err)) => {
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
}
} }
} }
} }
impl<S: Read + Write + Sync> DtlsMidHandshake<S> for MidHandshakeSslStream<S> { impl<S: AsyncRead + AsyncWrite + Unpin> Dtls<S> for Compat<SslStream<Compat<S>>> {
type Instance = SslStream<S>;
fn handshake(self) -> DtlsHandshakeResult<Self::Instance, Self> {
match MidHandshakeSslStream::handshake(self) {
Ok(stream) => DtlsHandshakeResult::Success(stream),
Err(HandshakeError::WouldBlock(mid_handshake)) => {
DtlsHandshakeResult::WouldBlock(mid_handshake)
}
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
),
Err(HandshakeError::SetupFailure(err)) => {
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
}
}
}
}
impl<S: Read + Write> Dtls<S> for SslStream<S> {
fn is_server_side(&self) -> bool { fn is_server_side(&self) -> bool {
self.ssl().is_server() self.get_ref().ssl().is_server()
} }
fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8> { fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8> {
let mut vec = vec![0; length]; let mut vec = vec![0; length];
self.ssl() self.get_mut()
.ssl()
.export_keying_material(&mut vec, exporter_label, None) .export_keying_material(&mut vec, exporter_label, None)
.unwrap(); .unwrap();
vec vec
@ -101,14 +70,16 @@ impl<S: Read + Write> Dtls<S> for SslStream<S> {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::rfc5764::test::DummyTransport; use crate::rfc5764::test::DummyTransport;
use crate::rfc5764::{DtlsSrtp, DtlsSrtpHandshakeResult}; use crate::rfc5764::DtlsSrtp;
use futures::FutureExt;
use openssl::asn1::Asn1Time; use openssl::asn1::Asn1Time;
use openssl::hash::MessageDigest; use openssl::hash::MessageDigest;
use openssl::pkey::PKey; use openssl::pkey::PKey;
use openssl::rsa::Rsa; use openssl::rsa::Rsa;
use openssl::ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode}; use openssl::ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode};
use openssl::x509::X509; use openssl::x509::X509;
use std::task::{Context, Poll};
#[test] #[test]
fn connect_and_establish_matching_key_material() { fn connect_and_establish_matching_key_material() {
@ -134,21 +105,16 @@ mod test {
acceptor.set_private_key(&key).unwrap(); acceptor.set_private_key(&key).unwrap();
acceptor.set_verify(SslVerifyMode::NONE); acceptor.set_verify(SslVerifyMode::NONE);
connector.set_verify(SslVerifyMode::NONE); connector.set_verify(SslVerifyMode::NONE);
let mut handshake_server = DtlsSrtp::handshake(server_sock, acceptor); let handshake_server = DtlsSrtp::handshake(server_sock, acceptor);
let mut handshake_client = DtlsSrtp::handshake(client_sock, connector); let handshake_client =
DtlsSrtp::handshake(client_sock, connector.build().configure().unwrap());
let mut future = futures::future::join(handshake_server, handshake_client).boxed();
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
loop { loop {
match (handshake_client, handshake_server) { if let Poll::Ready((server, client)) = future.as_mut().poll(&mut cx) {
(DtlsSrtpHandshakeResult::Failure(err), _) => { let server = server.unwrap();
panic!("Client error: {}", err); let client = client.unwrap();
}
(_, DtlsSrtpHandshakeResult::Failure(err)) => {
panic!("Server error: {}", err);
}
(
DtlsSrtpHandshakeResult::Success(client),
DtlsSrtpHandshakeResult::Success(server),
) => {
assert_eq!( assert_eq!(
client.srtp_read_context.master_key, client.srtp_read_context.master_key,
server.srtp_write_context.master_key server.srtp_write_context.master_key
@ -167,22 +133,6 @@ mod test {
); );
return; return;
} }
(
DtlsSrtpHandshakeResult::WouldBlock(client),
DtlsSrtpHandshakeResult::WouldBlock(server),
) => {
handshake_client = client.handshake();
handshake_server = server.handshake();
}
(client, DtlsSrtpHandshakeResult::WouldBlock(server)) => {
handshake_client = client;
handshake_server = server.handshake();
}
(DtlsSrtpHandshakeResult::WouldBlock(client), server) => {
handshake_client = client.handshake();
handshake_server = server;
}
}
} }
} }
} }

View File

@ -1,95 +0,0 @@
use std::io;
use tokio::prelude::{Async, AsyncRead, AsyncSink, AsyncWrite, Future, Sink, Stream};
use crate::rfc5764::{DtlsBuilder, DtlsSrtp, DtlsSrtpHandshakeResult, DtlsSrtpMuxerPart};
impl<S, D> AsyncRead for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
}
impl<S, D> AsyncWrite for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
fn shutdown(&mut self) -> io::Result<Async<()>> {
Ok(().into()) // FIXME
}
}
impl<S, D> Future for DtlsSrtpHandshakeResult<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
type Item = DtlsSrtp<S, D>;
type Error = io::Error;
fn poll(&mut self) -> io::Result<Async<Self::Item>> {
let mut owned = DtlsSrtpHandshakeResult::Failure(io::Error::new(
io::ErrorKind::Other,
"poll called after completion",
));
std::mem::swap(&mut owned, self);
match owned {
DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)),
DtlsSrtpHandshakeResult::WouldBlock(mid_handshake) => match mid_handshake.handshake() {
DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)),
mut new @ DtlsSrtpHandshakeResult::WouldBlock(_) => {
std::mem::swap(&mut new, self);
Ok(Async::NotReady)
}
DtlsSrtpHandshakeResult::Failure(err) => Err(err),
},
DtlsSrtpHandshakeResult::Failure(err) => Err(err),
}
}
}
impl<S, D> Stream for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
DtlsSrtp<S, D>: AsyncRead + AsyncWrite,
{
type Item = Vec<u8>;
type Error = io::Error;
fn poll(&mut self) -> io::Result<Async<Option<Self::Item>>> {
let mut buf = [0; 2048];
Ok(match self.poll_read(&mut buf)? {
Async::Ready(len) => {
if len == 0 {
Async::Ready(None)
} else {
Async::Ready(Some(buf[..len].to_vec()))
}
}
Async::NotReady => Async::NotReady,
})
}
}
impl<S, D> Sink for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
DtlsSrtp<S, D>: AsyncRead + AsyncWrite,
{
type SinkItem = Vec<u8>;
type SinkError = io::Error;
fn start_send(&mut self, buf: Self::SinkItem) -> io::Result<AsyncSink<Self::SinkItem>> {
Ok(match self.poll_write(&buf[..])? {
Async::Ready(_) => AsyncSink::Ready,
Async::NotReady => AsyncSink::NotReady(buf),
})
}
fn poll_complete(&mut self) -> io::Result<Async<()>> {
self.poll_flush()
}
}