Update rfc5764 to use futures 0.3 AsyncRead/Write

With futures 0.3, the AsyncRead/Write traits are no longer marker traits, so we
can no longer easily support both, non-blocking and blocking APIs.
The code can however still be used in a blocking way buy just calling
`now_or_never` on any futures it returns, which will block if the underlying
streams are blocking.
dtls-srtp
Jonas Herzig 2020-04-05 13:36:15 +02:00
parent 3510417ce4
commit 1444b3c063
5 changed files with 333 additions and 395 deletions

View File

@ -4,17 +4,24 @@ version = "0.1.0"
authors = ["Takeru Ohta <phjgt308@gmail.com>"]
edition = "2018"
[features]
default = []
rfc5764-openssl = ["openssl", "tokio-openssl", "tokio-util/compat"]
[dependencies]
trackable = "0.1"
handy_async = "0.2"
rust-crypto = "0.2"
num = "0.1"
fixedbitset = "0.1"
futures = "0.3"
async-trait = "0.1"
openssl = { version = "0.10", optional = true }
tokio = { version = "0.1", optional = true }
tokio-openssl = { version = "0.4", optional = true }
tokio-util = { version = "0.3", optional = true, default-features = false }
[dev-dependencies]
clap = "2"
fibers = "0.1"
futures = "0.1"
futures01 = { package = "futures", version = "0.1" }

View File

@ -1,6 +1,6 @@
extern crate clap;
extern crate fibers;
extern crate futures;
extern crate futures01 as futures;
#[macro_use]
extern crate trackable;
extern crate rtp;

View File

@ -1,18 +1,24 @@
// FIXME: the current SRTP implementation does not support the maximum_lifetime parameter
#[cfg(feature = "openssl")]
#[cfg(feature = "rfc5764-openssl")]
mod openssl;
#[cfg(feature = "tokio")]
mod tokio;
use async_trait::async_trait;
use futures::io::{AsyncRead, AsyncWrite};
use futures::ready;
use futures::{Sink, Stream};
use std::collections::VecDeque;
use std::io;
use std::io::{Read, Write};
use std::io::Read;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::Context;
use std::task::Poll;
use crate::rfc3711::{AuthenticationAlgorithm, Context, EncryptionAlgorithm, Srtcp, Srtp};
use crate::rfc3711::{
AuthenticationAlgorithm, Context as SrtpContext, EncryptionAlgorithm, Srtcp, Srtp,
};
use crate::types::Ssrc;
#[derive(Debug, Clone, PartialEq, Eq)]
@ -48,28 +54,16 @@ impl SrtpProtectionProfile {
&[&SrtpProtectionProfile::AES128_CM_HMAC_SHA1_80];
}
pub enum DtlsHandshakeResult<Dtls, DtlsMidHandshake> {
Failure(io::Error),
WouldBlock(DtlsMidHandshake),
Success(Dtls),
}
#[async_trait]
pub trait DtlsBuilder<S> {
type Instance: Dtls<S>;
type MidHandshake: DtlsMidHandshake<S, Instance = Self::Instance>;
fn handshake(self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake>
async fn handshake(self, stream: S) -> Result<Self::Instance, io::Error>
where
S: Read + Write;
S: AsyncRead + AsyncWrite + Unpin + 'async_trait;
}
pub trait DtlsMidHandshake<S>: Sized {
type Instance: Dtls<S>;
fn handshake(self) -> DtlsHandshakeResult<Self::Instance, Self>;
}
pub trait Dtls<S>: Read + Write {
pub trait Dtls<S>: AsyncRead + AsyncWrite + Unpin {
fn is_server_side(&self) -> bool;
fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8>;
}
@ -80,7 +74,7 @@ pub struct DtlsSrtpMuxer<S> {
srtp_buf: VecDeque<Vec<u8>>,
}
impl<S: Read + Write> DtlsSrtpMuxer<S> {
impl<S: AsyncRead + AsyncWrite> DtlsSrtpMuxer<S> {
fn new(inner: S) -> Self {
DtlsSrtpMuxer {
inner,
@ -102,8 +96,13 @@ impl<S> DtlsSrtpMuxer<S> {
}
}
impl<S: Read> DtlsSrtpMuxer<S> {
fn read(&mut self, want_srtp: bool, dst_buf: &mut [u8]) -> io::Result<usize> {
impl<S: AsyncRead + Unpin> DtlsSrtpMuxer<S> {
fn read(
&mut self,
cx: &mut Context,
want_srtp: bool,
dst_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
{
let want_buf = if want_srtp {
&mut self.srtp_buf
@ -111,20 +110,20 @@ impl<S: Read> DtlsSrtpMuxer<S> {
&mut self.dtls_buf
};
if let Some(buf) = want_buf.pop_front() {
return (&buf[..]).read(dst_buf);
return Poll::Ready((&buf[..]).read(dst_buf));
}
}
let mut buf = [0u8; 2048];
let len = self.inner.read(&mut buf)?;
let len = ready!(Pin::new(&mut self.inner).poll_read(cx, &mut buf))?;
if len == 0 {
return Ok(0);
return Poll::Ready(Ok(0));
}
let mut buf = &buf[..len];
// Demux SRTP and DTLS as per https://tools.ietf.org/html/rfc5764#section-5.1.2
let is_srtp = buf[0] >= 128 && buf[0] <= 191;
if is_srtp == want_srtp {
buf.read(dst_buf)
Poll::Ready(buf.read(dst_buf))
} else {
if is_srtp {
&mut self.srtp_buf
@ -138,7 +137,7 @@ impl<S: Read> DtlsSrtpMuxer<S> {
// by pretending that we're doing non-blocking io (even if we aren't)
// to get back to where we can enter the other (in the example: the dtls)
// read-path and process the packet we just read.
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
Poll::Pending // FIXME this doesn't see sound, shouldn't we store the waker!?
}
}
}
@ -148,86 +147,61 @@ pub struct DtlsSrtpMuxerPart<S> {
srtp: bool,
}
impl<S> Read for DtlsSrtpMuxerPart<S>
impl<S> AsyncRead for DtlsSrtpMuxerPart<S>
where
S: Read,
S: AsyncRead + Unpin,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.muxer.lock().unwrap().read(self.srtp, buf)
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.muxer.lock().unwrap().read(cx, self.srtp, buf)
}
}
impl<S> Write for DtlsSrtpMuxerPart<S>
impl<S> AsyncWrite for DtlsSrtpMuxerPart<S>
where
S: Write,
S: AsyncWrite + Unpin,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.muxer.lock().unwrap().inner.write(buf)
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.muxer.lock().unwrap().inner).poll_write(cx, buf)
}
fn flush(&mut self) -> io::Result<()> {
self.muxer.lock().unwrap().inner.flush()
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.muxer.lock().unwrap().inner).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
ready!(self.as_mut().poll_flush(cx))?;
Pin::new(&mut self.muxer.lock().unwrap().inner).poll_close(cx)
}
}
pub enum DtlsSrtpHandshakeResult<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
Success(DtlsSrtp<S, D>),
WouldBlock(DtlsSrtpMidHandshake<S, D>),
Failure(io::Error),
}
pub struct DtlsSrtpMidHandshake<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
stream: DtlsSrtpMuxerPart<S>,
dtls: D::MidHandshake,
}
pub struct DtlsSrtp<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
pub struct DtlsSrtp<S: AsyncRead + AsyncWrite, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
stream: DtlsSrtpMuxerPart<S>,
#[allow(dead_code)] // we'll need this once we implement re-keying
dtls: D::Instance,
srtp_read_context: Context<Srtp>,
srtcp_read_context: Context<Srtcp>,
srtp_write_context: Context<Srtp>,
srtcp_write_context: Context<Srtcp>,
}
impl<S, D> DtlsSrtpMidHandshake<S, D>
where
S: Read + Write + Sized,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
pub fn handshake(mut self) -> DtlsSrtpHandshakeResult<S, D> {
match self.dtls.handshake() {
DtlsHandshakeResult::Success(dtls) => {
DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(self.stream, dtls))
}
DtlsHandshakeResult::WouldBlock(dtls) => {
self.dtls = dtls;
DtlsSrtpHandshakeResult::WouldBlock(self)
}
DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err),
}
}
srtp_read_context: SrtpContext<Srtp>,
srtcp_read_context: SrtpContext<Srtcp>,
srtp_write_context: SrtpContext<Srtp>,
srtcp_write_context: SrtpContext<Srtcp>,
sink_buf: Option<Vec<u8>>,
}
impl<S, D> DtlsSrtp<S, D>
where
S: Read + Write,
S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
pub fn handshake(stream: S, dtls_builder: D) -> DtlsSrtpHandshakeResult<S, D> {
pub async fn handshake(stream: S, dtls_builder: D) -> Result<DtlsSrtp<S, D>, io::Error> {
let (stream_dtls, stream_srtp) = DtlsSrtpMuxer::new(stream).into_parts();
match dtls_builder.handshake(stream_dtls) {
DtlsHandshakeResult::Success(dtls) => {
DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(stream_srtp, dtls))
}
DtlsHandshakeResult::WouldBlock(dtls) => {
DtlsSrtpHandshakeResult::WouldBlock(DtlsSrtpMidHandshake {
stream: stream_srtp,
dtls,
})
}
DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err),
}
let dtls = dtls_builder.handshake(stream_dtls).await?;
Ok(DtlsSrtp::new(stream_srtp, dtls))
}
fn new(stream: DtlsSrtpMuxerPart<S>, mut dtls: D::Instance) -> Self {
@ -255,10 +229,11 @@ where
DtlsSrtp {
stream,
dtls,
srtp_read_context: Context::new(&read_key, &read_salt),
srtcp_read_context: Context::new(&read_key, &read_salt),
srtp_write_context: Context::new(&write_key, &write_salt),
srtcp_write_context: Context::new(&write_key, &write_salt),
srtp_read_context: SrtpContext::new(&read_key, &read_salt),
srtcp_read_context: SrtpContext::new(&read_key, &read_salt),
srtp_write_context: SrtpContext::new(&write_key, &write_salt),
srtcp_write_context: SrtpContext::new(&write_key, &write_salt),
sink_buf: None,
}
}
@ -303,12 +278,33 @@ where
}
}
impl<S, D> Read for DtlsSrtp<S, D>
impl<S, D> AsyncRead for DtlsSrtp<S, D>
where
S: Read + Write,
S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let item = ready!(self.poll_next(cx)?);
if let Some(item) = item {
Poll::Ready((&item[..]).read(buf))
} else {
Poll::Ready(Ok(0))
}
}
}
impl<S, D> Stream for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
type Item = io::Result<Vec<u8>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
// Check if we have an SRTP packet in the queue
if self.stream.muxer.lock().unwrap().srtp_buf.is_empty() {
@ -331,13 +327,13 @@ where
// Read and handle the next SRTP packet from the queue
let mut raw_buf = [0u8; 2048];
let len = self.stream.read(&mut raw_buf)?;
let len = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut raw_buf))?;
if len == 0 {
return Ok(0);
return Poll::Ready(None);
}
let raw_buf = &raw_buf[..len];
return match self.process_incoming_srtp_packet(raw_buf) {
Some(result) => (&result[..]).read(buf),
Some(result) => Poll::Ready(Some(Ok(result))),
None => {
// FIXME: check rfc for whether this should be dropped silently
continue; // packet failed to auth or decrypt, drop it and try the next one
@ -347,21 +343,68 @@ where
}
}
impl<S, D> Write for DtlsSrtp<S, D>
impl<S, D> AsyncWrite for DtlsSrtp<S, D>
where
S: Read + Write,
S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if let Some(buf) = self.process_outgoing_srtp_packet(buf) {
self.stream.write(&buf)
Pin::new(&mut self.stream).poll_write(cx, &buf)
} else {
Ok(buf.len())
Poll::Ready(Ok(buf.len()))
}
}
fn flush(&mut self) -> io::Result<()> {
self.stream.flush()
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_close(cx)
}
}
impl<S, D> Sink<&[u8]> for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite + Unpin,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
type Error = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let _ = Sink::poll_flush(self.as_mut(), cx)?;
if self.sink_buf.is_none() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
fn start_send(mut self: Pin<&mut Self>, item: &[u8]) -> io::Result<()> {
self.sink_buf = self.process_outgoing_srtp_packet(item.as_ref());
Ok(())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
if let Some(buf) = self.sink_buf.take() {
match Pin::new(&mut self.stream).poll_write(cx, &buf) {
Poll::Pending => {
self.sink_buf = Some(buf);
return Poll::Pending;
}
_ => {}
}
}
Pin::new(&mut self.stream).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_close(cx)
}
}
@ -373,6 +416,28 @@ pub(crate) mod test {
TEST_SRTP_SSRC,
};
use futures::{AsyncReadExt, AsyncWriteExt, FutureExt};
macro_rules! read_now {
( $expr:expr, $buf:expr ) => {
$expr
.read($buf)
.now_or_never()
.expect("would block")
.expect("reading")
};
}
macro_rules! write_now {
( $expr:expr, $buf:expr ) => {
$expr
.write($buf)
.now_or_never()
.expect("would block")
.expect("writing")
};
}
struct DummyDtlsBuilder;
struct DummyDtls<S> {
connected: bool,
@ -388,38 +453,31 @@ pub(crate) mod test {
DummyDtlsBuilder {}
}
}
impl<S: Read + Write> DtlsBuilder<S> for DummyDtlsBuilder {
#[async_trait]
impl<S: AsyncRead + AsyncWrite + Unpin + Send> DtlsBuilder<S> for DummyDtlsBuilder {
type Instance = DummyDtls<S>;
type MidHandshake = DummyDtls<S>;
fn handshake(
self,
mut stream: S,
) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake> {
stream.write(DUMMY_DTLS_HELLO).unwrap();
DummyDtls {
async fn handshake(self, mut stream: S) -> Result<Self::Instance, io::Error>
where
S: 'async_trait,
{
let _ = stream.write(DUMMY_DTLS_HELLO).await;
let mut dtls = DummyDtls {
stream,
connected: false,
}
.handshake()
}
}
impl<S: Read + Write> DtlsMidHandshake<S> for DummyDtls<S> {
type Instance = Self;
fn handshake(mut self) -> DtlsHandshakeResult<Self::Instance, Self> {
let result = self.read(&mut []).unwrap_err();
if result.kind() == io::ErrorKind::WouldBlock {
if self.connected {
DtlsHandshakeResult::Success(self)
};
loop {
let _ = futures::poll!(dtls.read(&mut []));
if dtls.connected {
break;
} else {
DtlsHandshakeResult::WouldBlock(self)
futures::pending!();
}
} else {
DtlsHandshakeResult::Failure(result)
}
Ok(dtls)
}
}
impl<S: Read + Write> Dtls<S> for DummyDtls<S> {
impl<S: AsyncRead + AsyncWrite + Unpin> Dtls<S> for DummyDtls<S> {
fn is_server_side(&self) -> bool {
true
}
@ -440,17 +498,21 @@ pub(crate) mod test {
}
}
impl<S: Read + Write> Read for DummyDtls<S> {
fn read(&mut self, _dst: &mut [u8]) -> io::Result<usize> {
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for DummyDtls<S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
_dst: &mut [u8],
) -> Poll<io::Result<usize>> {
loop {
let mut buf = [0u8; 2048];
let len = self.stream.read(&mut buf)?;
let len = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?;
assert_eq!(len, 2);
assert_eq!(buf[1], 42);
match &buf[..len] {
DUMMY_DTLS_NOOP => {}
DUMMY_DTLS_HELLO => {
self.stream.write(DUMMY_DTLS_CONNECTED)?;
let _ = Pin::new(&mut self.stream).poll_write(cx, DUMMY_DTLS_CONNECTED)?;
}
DUMMY_DTLS_CONNECTED => {
self.connected = true;
@ -461,13 +523,21 @@ pub(crate) mod test {
}
}
impl<S: Write> Write for DummyDtls<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.stream.write(buf)
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for DummyDtls<S> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
}
fn flush(&mut self) -> io::Result<()> {
self.stream.flush()
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
}
@ -498,76 +568,82 @@ pub(crate) mod test {
}
}
impl Read for DummyTransport {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
impl AsyncRead for DummyTransport {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.read_buf.lock().unwrap().pop_front() {
None => Err(io::Error::new(io::ErrorKind::WouldBlock, "")),
Some(elem) => (&mut &elem[..]).read(buf),
None => Poll::Pending,
Some(elem) => Poll::Ready(std::io::Read::read(&mut &elem[..], buf)),
}
}
}
impl Write for DummyTransport {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
impl AsyncWrite for DummyTransport {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.write_buf.lock().unwrap().push_back(buf.to_vec());
Ok(buf.len())
Poll::Ready(Ok(buf.len()))
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
macro_rules! assert_wouldblock {
( $expr:expr ) => {
let err = $expr.unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
};
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
fn new_dtls_srtp() -> (DummyTransport, DtlsSrtp<DummyTransport, DummyDtlsBuilder>) {
let (mut stream, dummy_stream) = DummyTransport::new();
stream.write(DUMMY_DTLS_CONNECTED).unwrap();
match DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()) {
DtlsSrtpHandshakeResult::Success(mut dtls_srtp) => {
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC);
dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC);
dtls_srtp.add_outgoing_ssrc(TEST_SRTP_SSRC);
dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC);
(stream, dtls_srtp)
}
_ => panic!("DTLS-SRTP handshake failed"),
}
write_now!(stream, DUMMY_DTLS_CONNECTED);
let mut dtls_srtp = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new())
.now_or_never()
.expect("DTLS-SRTP handshake did not complete")
.expect("DTL-SRTP handshake failed");
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC);
dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC);
dtls_srtp.add_outgoing_ssrc(TEST_SRTP_SSRC);
dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC);
(stream, dtls_srtp)
}
#[test]
fn polls_dtls_layer_for_keys() {
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
let (mut stream, dummy_stream) = DummyTransport::new();
let handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new());
let handshake = match handshake {
DtlsSrtpHandshakeResult::WouldBlock(it) => it,
let mut handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()).boxed();
match handshake.as_mut().poll(&mut cx) {
Poll::Pending => {}
_ => panic!(),
};
stream.write(TEST_SRTP_PACKET).unwrap(); // too early, should be discarded
// too early, should be discarded
write_now!(stream, TEST_SRTP_PACKET);
let handshake = match handshake.handshake() {
DtlsSrtpHandshakeResult::WouldBlock(it) => it,
match handshake.as_mut().poll(&mut cx) {
Poll::Pending => {}
_ => panic!(),
};
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
stream.write(DUMMY_DTLS_HELLO).unwrap();
let handshake = match handshake.handshake() {
DtlsSrtpHandshakeResult::WouldBlock(it) => it,
write_now!(stream, DUMMY_DTLS_HELLO);
match handshake.as_mut().poll(&mut cx) {
Poll::Pending => {}
_ => panic!(),
};
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_CONNECTED);
stream.write(DUMMY_DTLS_CONNECTED).unwrap();
match handshake.handshake() {
DtlsSrtpHandshakeResult::Success(_) => {}
write_now!(stream, DUMMY_DTLS_CONNECTED);
match handshake.as_mut().poll(&mut cx) {
Poll::Ready(_) => {}
_ => panic!(),
};
}
@ -577,10 +653,10 @@ pub(crate) mod test {
let mut buf = [0u8; 2048];
let (mut stream, mut dtls_srtp) = new_dtls_srtp();
stream.write(TEST_SRTP_PACKET).unwrap();
stream.write(TEST_SRTCP_PACKET).unwrap();
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182); // srtp
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68); // srtcp
write_now!(stream, TEST_SRTP_PACKET);
write_now!(stream, TEST_SRTCP_PACKET);
assert_eq!(read_now!(dtls_srtp, &mut buf), 182); // srtp
assert_eq!(read_now!(dtls_srtp, &mut buf), 68); // srtcp
}
#[test]
@ -588,16 +664,16 @@ pub(crate) mod test {
let mut buf = [0u8; 2048];
let (mut stream, mut dtls_srtp) = new_dtls_srtp();
stream.write(TEST_SRTP_PACKET).unwrap();
stream.write(TEST_SRTP_PACKET).unwrap();
stream.write(TEST_SRTP_PACKET).unwrap();
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182);
assert_wouldblock!(dtls_srtp.read(&mut buf));
write_now!(stream, TEST_SRTP_PACKET);
write_now!(stream, TEST_SRTP_PACKET);
write_now!(stream, TEST_SRTP_PACKET);
assert_eq!(read_now!(dtls_srtp, &mut buf), 182);
assert!(dtls_srtp.read(&mut buf).now_or_never().is_none(),);
stream.write(TEST_SRTCP_PACKET).unwrap();
stream.write(TEST_SRTCP_PACKET).unwrap();
stream.write(TEST_SRTCP_PACKET).unwrap();
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68);
assert_wouldblock!(dtls_srtp.read(&mut buf));
write_now!(stream, TEST_SRTCP_PACKET);
write_now!(stream, TEST_SRTCP_PACKET);
write_now!(stream, TEST_SRTCP_PACKET);
assert_eq!(read_now!(dtls_srtp, &mut buf), 68);
assert!(dtls_srtp.read(&mut buf).now_or_never().is_none(),);
}
}

View File

@ -1,19 +1,22 @@
use async_trait::async_trait;
use futures::io::{AsyncRead, AsyncWrite};
use std::io;
use std::io::{Read, Write};
use tokio_openssl::SslStream;
use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, Tokio02AsyncReadCompatExt};
use openssl::ssl::{
HandshakeError, MidHandshakeSslStream, SslAcceptorBuilder, SslConnectorBuilder, SslStream,
};
use openssl::ssl::{ConnectConfiguration, SslAcceptorBuilder};
use crate::rfc5764::{Dtls, DtlsBuilder, DtlsHandshakeResult, DtlsMidHandshake, SrtpProtectionProfile};
use crate::rfc5764::{Dtls, DtlsBuilder, SrtpProtectionProfile};
impl<S: Read + Write + Sync> DtlsBuilder<S> for SslConnectorBuilder {
type Instance = SslStream<S>;
type MidHandshake = MidHandshakeSslStream<S>;
type CompatSslStream<S> = Compat<SslStream<Compat<S>>>;
fn handshake(mut self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake>
#[async_trait]
impl<S: AsyncRead + AsyncWrite + Send + Unpin> DtlsBuilder<S> for ConnectConfiguration {
type Instance = CompatSslStream<S>;
async fn handshake(mut self, stream: S) -> Result<Self::Instance, io::Error>
where
S: Read + Write,
S: 'async_trait,
{
let profiles_str: String = SrtpProtectionProfile::RECOMMENDED
.iter()
@ -21,28 +24,20 @@ impl<S: Read + Write + Sync> DtlsBuilder<S> for SslConnectorBuilder {
.collect::<Vec<_>>()
.join(":");
self.set_tlsext_use_srtp(&profiles_str).unwrap();
match self.build().connect("invalid", stream) {
Ok(stream) => DtlsHandshakeResult::Success(stream),
Err(HandshakeError::WouldBlock(mid_handshake)) => {
DtlsHandshakeResult::WouldBlock(mid_handshake)
}
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
),
Err(HandshakeError::SetupFailure(err)) => {
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
}
match tokio_openssl::connect(self, "invalid", stream.compat()).await {
Ok(stream) => Ok(stream.compat()),
Err(_) => Err(io::Error::new(io::ErrorKind::Other, "handshake error")),
}
}
}
impl<S: Read + Write + Sync> DtlsBuilder<S> for SslAcceptorBuilder {
type Instance = SslStream<S>;
type MidHandshake = MidHandshakeSslStream<S>;
#[async_trait]
impl<S: AsyncRead + AsyncWrite + Send + Unpin> DtlsBuilder<S> for SslAcceptorBuilder {
type Instance = CompatSslStream<S>;
fn handshake(mut self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake>
async fn handshake(mut self, stream: S) -> Result<Self::Instance, io::Error>
where
S: Read + Write,
S: 'async_trait,
{
let profiles_str: String = SrtpProtectionProfile::RECOMMENDED
.iter()
@ -50,48 +45,22 @@ impl<S: Read + Write + Sync> DtlsBuilder<S> for SslAcceptorBuilder {
.collect::<Vec<_>>()
.join(":");
self.set_tlsext_use_srtp(&profiles_str).unwrap();
match self.build().accept(stream) {
Ok(stream) => DtlsHandshakeResult::Success(stream),
Err(HandshakeError::WouldBlock(mid_handshake)) => {
DtlsHandshakeResult::WouldBlock(mid_handshake)
}
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
),
Err(HandshakeError::SetupFailure(err)) => {
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
}
match tokio_openssl::accept(&self.build(), stream.compat()).await {
Ok(stream) => Ok(stream.compat()),
Err(_) => Err(io::Error::new(io::ErrorKind::Other, "handshake error")),
}
}
}
impl<S: Read + Write + Sync> DtlsMidHandshake<S> for MidHandshakeSslStream<S> {
type Instance = SslStream<S>;
fn handshake(self) -> DtlsHandshakeResult<Self::Instance, Self> {
match MidHandshakeSslStream::handshake(self) {
Ok(stream) => DtlsHandshakeResult::Success(stream),
Err(HandshakeError::WouldBlock(mid_handshake)) => {
DtlsHandshakeResult::WouldBlock(mid_handshake)
}
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
),
Err(HandshakeError::SetupFailure(err)) => {
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
}
}
}
}
impl<S: Read + Write> Dtls<S> for SslStream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin> Dtls<S> for Compat<SslStream<Compat<S>>> {
fn is_server_side(&self) -> bool {
self.ssl().is_server()
self.get_ref().ssl().is_server()
}
fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8> {
let mut vec = vec![0; length];
self.ssl()
self.get_mut()
.ssl()
.export_keying_material(&mut vec, exporter_label, None)
.unwrap();
vec
@ -101,14 +70,16 @@ impl<S: Read + Write> Dtls<S> for SslStream<S> {
#[cfg(test)]
mod test {
use crate::rfc5764::test::DummyTransport;
use crate::rfc5764::{DtlsSrtp, DtlsSrtpHandshakeResult};
use crate::rfc5764::DtlsSrtp;
use futures::FutureExt;
use openssl::asn1::Asn1Time;
use openssl::hash::MessageDigest;
use openssl::pkey::PKey;
use openssl::rsa::Rsa;
use openssl::ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode};
use openssl::x509::X509;
use std::task::{Context, Poll};
#[test]
fn connect_and_establish_matching_key_material() {
@ -134,54 +105,33 @@ mod test {
acceptor.set_private_key(&key).unwrap();
acceptor.set_verify(SslVerifyMode::NONE);
connector.set_verify(SslVerifyMode::NONE);
let mut handshake_server = DtlsSrtp::handshake(server_sock, acceptor);
let mut handshake_client = DtlsSrtp::handshake(client_sock, connector);
let handshake_server = DtlsSrtp::handshake(server_sock, acceptor);
let handshake_client =
DtlsSrtp::handshake(client_sock, connector.build().configure().unwrap());
let mut future = futures::future::join(handshake_server, handshake_client).boxed();
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
loop {
match (handshake_client, handshake_server) {
(DtlsSrtpHandshakeResult::Failure(err), _) => {
panic!("Client error: {}", err);
}
(_, DtlsSrtpHandshakeResult::Failure(err)) => {
panic!("Server error: {}", err);
}
(
DtlsSrtpHandshakeResult::Success(client),
DtlsSrtpHandshakeResult::Success(server),
) => {
assert_eq!(
client.srtp_read_context.master_key,
server.srtp_write_context.master_key
);
assert_eq!(
client.srtp_read_context.master_salt,
server.srtp_write_context.master_salt
);
assert_eq!(
client.srtp_write_context.master_key,
server.srtp_read_context.master_key
);
assert_eq!(
client.srtp_write_context.master_salt,
server.srtp_read_context.master_salt
);
return;
}
(
DtlsSrtpHandshakeResult::WouldBlock(client),
DtlsSrtpHandshakeResult::WouldBlock(server),
) => {
handshake_client = client.handshake();
handshake_server = server.handshake();
}
(client, DtlsSrtpHandshakeResult::WouldBlock(server)) => {
handshake_client = client;
handshake_server = server.handshake();
}
(DtlsSrtpHandshakeResult::WouldBlock(client), server) => {
handshake_client = client.handshake();
handshake_server = server;
}
if let Poll::Ready((server, client)) = future.as_mut().poll(&mut cx) {
let server = server.unwrap();
let client = client.unwrap();
assert_eq!(
client.srtp_read_context.master_key,
server.srtp_write_context.master_key
);
assert_eq!(
client.srtp_read_context.master_salt,
server.srtp_write_context.master_salt
);
assert_eq!(
client.srtp_write_context.master_key,
server.srtp_read_context.master_key
);
assert_eq!(
client.srtp_write_context.master_salt,
server.srtp_read_context.master_salt
);
return;
}
}
}

View File

@ -1,95 +0,0 @@
use std::io;
use tokio::prelude::{Async, AsyncRead, AsyncSink, AsyncWrite, Future, Sink, Stream};
use crate::rfc5764::{DtlsBuilder, DtlsSrtp, DtlsSrtpHandshakeResult, DtlsSrtpMuxerPart};
impl<S, D> AsyncRead for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
}
impl<S, D> AsyncWrite for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
fn shutdown(&mut self) -> io::Result<Async<()>> {
Ok(().into()) // FIXME
}
}
impl<S, D> Future for DtlsSrtpHandshakeResult<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
{
type Item = DtlsSrtp<S, D>;
type Error = io::Error;
fn poll(&mut self) -> io::Result<Async<Self::Item>> {
let mut owned = DtlsSrtpHandshakeResult::Failure(io::Error::new(
io::ErrorKind::Other,
"poll called after completion",
));
std::mem::swap(&mut owned, self);
match owned {
DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)),
DtlsSrtpHandshakeResult::WouldBlock(mid_handshake) => match mid_handshake.handshake() {
DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)),
mut new @ DtlsSrtpHandshakeResult::WouldBlock(_) => {
std::mem::swap(&mut new, self);
Ok(Async::NotReady)
}
DtlsSrtpHandshakeResult::Failure(err) => Err(err),
},
DtlsSrtpHandshakeResult::Failure(err) => Err(err),
}
}
}
impl<S, D> Stream for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
DtlsSrtp<S, D>: AsyncRead + AsyncWrite,
{
type Item = Vec<u8>;
type Error = io::Error;
fn poll(&mut self) -> io::Result<Async<Option<Self::Item>>> {
let mut buf = [0; 2048];
Ok(match self.poll_read(&mut buf)? {
Async::Ready(len) => {
if len == 0 {
Async::Ready(None)
} else {
Async::Ready(Some(buf[..len].to_vec()))
}
}
Async::NotReady => Async::NotReady,
})
}
}
impl<S, D> Sink for DtlsSrtp<S, D>
where
S: AsyncRead + AsyncWrite,
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
DtlsSrtp<S, D>: AsyncRead + AsyncWrite,
{
type SinkItem = Vec<u8>;
type SinkError = io::Error;
fn start_send(&mut self, buf: Self::SinkItem) -> io::Result<AsyncSink<Self::SinkItem>> {
Ok(match self.poll_write(&buf[..])? {
Async::Ready(_) => AsyncSink::Ready,
Async::NotReady => AsyncSink::NotReady(buf),
})
}
fn poll_complete(&mut self) -> io::Result<Async<()>> {
self.poll_flush()
}
}