Update rfc5764 to use futures 0.3 AsyncRead/Write
With futures 0.3, the AsyncRead/Write traits are no longer marker traits, so we can no longer easily support both, non-blocking and blocking APIs. The code can however still be used in a blocking way buy just calling `now_or_never` on any futures it returns, which will block if the underlying streams are blocking.dtls-srtp
parent
3510417ce4
commit
1444b3c063
11
Cargo.toml
11
Cargo.toml
|
@ -4,17 +4,24 @@ version = "0.1.0"
|
||||||
authors = ["Takeru Ohta <phjgt308@gmail.com>"]
|
authors = ["Takeru Ohta <phjgt308@gmail.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
rfc5764-openssl = ["openssl", "tokio-openssl", "tokio-util/compat"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
trackable = "0.1"
|
trackable = "0.1"
|
||||||
handy_async = "0.2"
|
handy_async = "0.2"
|
||||||
rust-crypto = "0.2"
|
rust-crypto = "0.2"
|
||||||
num = "0.1"
|
num = "0.1"
|
||||||
fixedbitset = "0.1"
|
fixedbitset = "0.1"
|
||||||
|
futures = "0.3"
|
||||||
|
async-trait = "0.1"
|
||||||
|
|
||||||
openssl = { version = "0.10", optional = true }
|
openssl = { version = "0.10", optional = true }
|
||||||
tokio = { version = "0.1", optional = true }
|
tokio-openssl = { version = "0.4", optional = true }
|
||||||
|
tokio-util = { version = "0.3", optional = true, default-features = false }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
clap = "2"
|
clap = "2"
|
||||||
fibers = "0.1"
|
fibers = "0.1"
|
||||||
futures = "0.1"
|
futures01 = { package = "futures", version = "0.1" }
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
extern crate clap;
|
extern crate clap;
|
||||||
extern crate fibers;
|
extern crate fibers;
|
||||||
extern crate futures;
|
extern crate futures01 as futures;
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate trackable;
|
extern crate trackable;
|
||||||
extern crate rtp;
|
extern crate rtp;
|
||||||
|
|
|
@ -1,18 +1,24 @@
|
||||||
// FIXME: the current SRTP implementation does not support the maximum_lifetime parameter
|
// FIXME: the current SRTP implementation does not support the maximum_lifetime parameter
|
||||||
|
|
||||||
#[cfg(feature = "openssl")]
|
#[cfg(feature = "rfc5764-openssl")]
|
||||||
mod openssl;
|
mod openssl;
|
||||||
|
|
||||||
#[cfg(feature = "tokio")]
|
use async_trait::async_trait;
|
||||||
mod tokio;
|
use futures::io::{AsyncRead, AsyncWrite};
|
||||||
|
use futures::ready;
|
||||||
|
use futures::{Sink, Stream};
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::{Read, Write};
|
use std::io::Read;
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
use std::task::Context;
|
||||||
|
use std::task::Poll;
|
||||||
|
|
||||||
use crate::rfc3711::{AuthenticationAlgorithm, Context, EncryptionAlgorithm, Srtcp, Srtp};
|
use crate::rfc3711::{
|
||||||
|
AuthenticationAlgorithm, Context as SrtpContext, EncryptionAlgorithm, Srtcp, Srtp,
|
||||||
|
};
|
||||||
use crate::types::Ssrc;
|
use crate::types::Ssrc;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
@ -48,28 +54,16 @@ impl SrtpProtectionProfile {
|
||||||
&[&SrtpProtectionProfile::AES128_CM_HMAC_SHA1_80];
|
&[&SrtpProtectionProfile::AES128_CM_HMAC_SHA1_80];
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum DtlsHandshakeResult<Dtls, DtlsMidHandshake> {
|
#[async_trait]
|
||||||
Failure(io::Error),
|
|
||||||
WouldBlock(DtlsMidHandshake),
|
|
||||||
Success(Dtls),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait DtlsBuilder<S> {
|
pub trait DtlsBuilder<S> {
|
||||||
type Instance: Dtls<S>;
|
type Instance: Dtls<S>;
|
||||||
type MidHandshake: DtlsMidHandshake<S, Instance = Self::Instance>;
|
|
||||||
|
|
||||||
fn handshake(self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake>
|
async fn handshake(self, stream: S) -> Result<Self::Instance, io::Error>
|
||||||
where
|
where
|
||||||
S: Read + Write;
|
S: AsyncRead + AsyncWrite + Unpin + 'async_trait;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait DtlsMidHandshake<S>: Sized {
|
pub trait Dtls<S>: AsyncRead + AsyncWrite + Unpin {
|
||||||
type Instance: Dtls<S>;
|
|
||||||
|
|
||||||
fn handshake(self) -> DtlsHandshakeResult<Self::Instance, Self>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Dtls<S>: Read + Write {
|
|
||||||
fn is_server_side(&self) -> bool;
|
fn is_server_side(&self) -> bool;
|
||||||
fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8>;
|
fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8>;
|
||||||
}
|
}
|
||||||
|
@ -80,7 +74,7 @@ pub struct DtlsSrtpMuxer<S> {
|
||||||
srtp_buf: VecDeque<Vec<u8>>,
|
srtp_buf: VecDeque<Vec<u8>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: Read + Write> DtlsSrtpMuxer<S> {
|
impl<S: AsyncRead + AsyncWrite> DtlsSrtpMuxer<S> {
|
||||||
fn new(inner: S) -> Self {
|
fn new(inner: S) -> Self {
|
||||||
DtlsSrtpMuxer {
|
DtlsSrtpMuxer {
|
||||||
inner,
|
inner,
|
||||||
|
@ -102,8 +96,13 @@ impl<S> DtlsSrtpMuxer<S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: Read> DtlsSrtpMuxer<S> {
|
impl<S: AsyncRead + Unpin> DtlsSrtpMuxer<S> {
|
||||||
fn read(&mut self, want_srtp: bool, dst_buf: &mut [u8]) -> io::Result<usize> {
|
fn read(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context,
|
||||||
|
want_srtp: bool,
|
||||||
|
dst_buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
{
|
{
|
||||||
let want_buf = if want_srtp {
|
let want_buf = if want_srtp {
|
||||||
&mut self.srtp_buf
|
&mut self.srtp_buf
|
||||||
|
@ -111,20 +110,20 @@ impl<S: Read> DtlsSrtpMuxer<S> {
|
||||||
&mut self.dtls_buf
|
&mut self.dtls_buf
|
||||||
};
|
};
|
||||||
if let Some(buf) = want_buf.pop_front() {
|
if let Some(buf) = want_buf.pop_front() {
|
||||||
return (&buf[..]).read(dst_buf);
|
return Poll::Ready((&buf[..]).read(dst_buf));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut buf = [0u8; 2048];
|
let mut buf = [0u8; 2048];
|
||||||
let len = self.inner.read(&mut buf)?;
|
let len = ready!(Pin::new(&mut self.inner).poll_read(cx, &mut buf))?;
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
return Ok(0);
|
return Poll::Ready(Ok(0));
|
||||||
}
|
}
|
||||||
let mut buf = &buf[..len];
|
let mut buf = &buf[..len];
|
||||||
// Demux SRTP and DTLS as per https://tools.ietf.org/html/rfc5764#section-5.1.2
|
// Demux SRTP and DTLS as per https://tools.ietf.org/html/rfc5764#section-5.1.2
|
||||||
let is_srtp = buf[0] >= 128 && buf[0] <= 191;
|
let is_srtp = buf[0] >= 128 && buf[0] <= 191;
|
||||||
if is_srtp == want_srtp {
|
if is_srtp == want_srtp {
|
||||||
buf.read(dst_buf)
|
Poll::Ready(buf.read(dst_buf))
|
||||||
} else {
|
} else {
|
||||||
if is_srtp {
|
if is_srtp {
|
||||||
&mut self.srtp_buf
|
&mut self.srtp_buf
|
||||||
|
@ -138,7 +137,7 @@ impl<S: Read> DtlsSrtpMuxer<S> {
|
||||||
// by pretending that we're doing non-blocking io (even if we aren't)
|
// by pretending that we're doing non-blocking io (even if we aren't)
|
||||||
// to get back to where we can enter the other (in the example: the dtls)
|
// to get back to where we can enter the other (in the example: the dtls)
|
||||||
// read-path and process the packet we just read.
|
// read-path and process the packet we just read.
|
||||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
Poll::Pending // FIXME this doesn't see sound, shouldn't we store the waker!?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -148,86 +147,61 @@ pub struct DtlsSrtpMuxerPart<S> {
|
||||||
srtp: bool,
|
srtp: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> Read for DtlsSrtpMuxerPart<S>
|
impl<S> AsyncRead for DtlsSrtpMuxerPart<S>
|
||||||
where
|
where
|
||||||
S: Read,
|
S: AsyncRead + Unpin,
|
||||||
{
|
{
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
fn poll_read(
|
||||||
self.muxer.lock().unwrap().read(self.srtp, buf)
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
self.muxer.lock().unwrap().read(cx, self.srtp, buf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> Write for DtlsSrtpMuxerPart<S>
|
impl<S> AsyncWrite for DtlsSrtpMuxerPart<S>
|
||||||
where
|
where
|
||||||
S: Write,
|
S: AsyncWrite + Unpin,
|
||||||
{
|
{
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn poll_write(
|
||||||
self.muxer.lock().unwrap().inner.write(buf)
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.muxer.lock().unwrap().inner).poll_write(cx, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
self.muxer.lock().unwrap().inner.flush()
|
Pin::new(&mut self.muxer.lock().unwrap().inner).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
|
ready!(self.as_mut().poll_flush(cx))?;
|
||||||
|
Pin::new(&mut self.muxer.lock().unwrap().inner).poll_close(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum DtlsSrtpHandshakeResult<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
|
pub struct DtlsSrtp<S: AsyncRead + AsyncWrite, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
|
||||||
Success(DtlsSrtp<S, D>),
|
|
||||||
WouldBlock(DtlsSrtpMidHandshake<S, D>),
|
|
||||||
Failure(io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DtlsSrtpMidHandshake<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
|
|
||||||
stream: DtlsSrtpMuxerPart<S>,
|
|
||||||
dtls: D::MidHandshake,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DtlsSrtp<S: Read + Write, D: DtlsBuilder<DtlsSrtpMuxerPart<S>>> {
|
|
||||||
stream: DtlsSrtpMuxerPart<S>,
|
stream: DtlsSrtpMuxerPart<S>,
|
||||||
|
#[allow(dead_code)] // we'll need this once we implement re-keying
|
||||||
dtls: D::Instance,
|
dtls: D::Instance,
|
||||||
srtp_read_context: Context<Srtp>,
|
srtp_read_context: SrtpContext<Srtp>,
|
||||||
srtcp_read_context: Context<Srtcp>,
|
srtcp_read_context: SrtpContext<Srtcp>,
|
||||||
srtp_write_context: Context<Srtp>,
|
srtp_write_context: SrtpContext<Srtp>,
|
||||||
srtcp_write_context: Context<Srtcp>,
|
srtcp_write_context: SrtpContext<Srtcp>,
|
||||||
}
|
sink_buf: Option<Vec<u8>>,
|
||||||
|
|
||||||
impl<S, D> DtlsSrtpMidHandshake<S, D>
|
|
||||||
where
|
|
||||||
S: Read + Write + Sized,
|
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
|
||||||
{
|
|
||||||
pub fn handshake(mut self) -> DtlsSrtpHandshakeResult<S, D> {
|
|
||||||
match self.dtls.handshake() {
|
|
||||||
DtlsHandshakeResult::Success(dtls) => {
|
|
||||||
DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(self.stream, dtls))
|
|
||||||
}
|
|
||||||
DtlsHandshakeResult::WouldBlock(dtls) => {
|
|
||||||
self.dtls = dtls;
|
|
||||||
DtlsSrtpHandshakeResult::WouldBlock(self)
|
|
||||||
}
|
|
||||||
DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, D> DtlsSrtp<S, D>
|
impl<S, D> DtlsSrtp<S, D>
|
||||||
where
|
where
|
||||||
S: Read + Write,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
||||||
{
|
{
|
||||||
pub fn handshake(stream: S, dtls_builder: D) -> DtlsSrtpHandshakeResult<S, D> {
|
pub async fn handshake(stream: S, dtls_builder: D) -> Result<DtlsSrtp<S, D>, io::Error> {
|
||||||
let (stream_dtls, stream_srtp) = DtlsSrtpMuxer::new(stream).into_parts();
|
let (stream_dtls, stream_srtp) = DtlsSrtpMuxer::new(stream).into_parts();
|
||||||
match dtls_builder.handshake(stream_dtls) {
|
let dtls = dtls_builder.handshake(stream_dtls).await?;
|
||||||
DtlsHandshakeResult::Success(dtls) => {
|
Ok(DtlsSrtp::new(stream_srtp, dtls))
|
||||||
DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(stream_srtp, dtls))
|
|
||||||
}
|
|
||||||
DtlsHandshakeResult::WouldBlock(dtls) => {
|
|
||||||
DtlsSrtpHandshakeResult::WouldBlock(DtlsSrtpMidHandshake {
|
|
||||||
stream: stream_srtp,
|
|
||||||
dtls,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new(stream: DtlsSrtpMuxerPart<S>, mut dtls: D::Instance) -> Self {
|
fn new(stream: DtlsSrtpMuxerPart<S>, mut dtls: D::Instance) -> Self {
|
||||||
|
@ -255,10 +229,11 @@ where
|
||||||
DtlsSrtp {
|
DtlsSrtp {
|
||||||
stream,
|
stream,
|
||||||
dtls,
|
dtls,
|
||||||
srtp_read_context: Context::new(&read_key, &read_salt),
|
srtp_read_context: SrtpContext::new(&read_key, &read_salt),
|
||||||
srtcp_read_context: Context::new(&read_key, &read_salt),
|
srtcp_read_context: SrtpContext::new(&read_key, &read_salt),
|
||||||
srtp_write_context: Context::new(&write_key, &write_salt),
|
srtp_write_context: SrtpContext::new(&write_key, &write_salt),
|
||||||
srtcp_write_context: Context::new(&write_key, &write_salt),
|
srtcp_write_context: SrtpContext::new(&write_key, &write_salt),
|
||||||
|
sink_buf: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -303,12 +278,33 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, D> Read for DtlsSrtp<S, D>
|
impl<S, D> AsyncRead for DtlsSrtp<S, D>
|
||||||
where
|
where
|
||||||
S: Read + Write,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
||||||
{
|
{
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
let item = ready!(self.poll_next(cx)?);
|
||||||
|
if let Some(item) = item {
|
||||||
|
Poll::Ready((&item[..]).read(buf))
|
||||||
|
} else {
|
||||||
|
Poll::Ready(Ok(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, D> Stream for DtlsSrtp<S, D>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
||||||
|
{
|
||||||
|
type Item = io::Result<Vec<u8>>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
loop {
|
loop {
|
||||||
// Check if we have an SRTP packet in the queue
|
// Check if we have an SRTP packet in the queue
|
||||||
if self.stream.muxer.lock().unwrap().srtp_buf.is_empty() {
|
if self.stream.muxer.lock().unwrap().srtp_buf.is_empty() {
|
||||||
|
@ -331,13 +327,13 @@ where
|
||||||
|
|
||||||
// Read and handle the next SRTP packet from the queue
|
// Read and handle the next SRTP packet from the queue
|
||||||
let mut raw_buf = [0u8; 2048];
|
let mut raw_buf = [0u8; 2048];
|
||||||
let len = self.stream.read(&mut raw_buf)?;
|
let len = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut raw_buf))?;
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
return Ok(0);
|
return Poll::Ready(None);
|
||||||
}
|
}
|
||||||
let raw_buf = &raw_buf[..len];
|
let raw_buf = &raw_buf[..len];
|
||||||
return match self.process_incoming_srtp_packet(raw_buf) {
|
return match self.process_incoming_srtp_packet(raw_buf) {
|
||||||
Some(result) => (&result[..]).read(buf),
|
Some(result) => Poll::Ready(Some(Ok(result))),
|
||||||
None => {
|
None => {
|
||||||
// FIXME: check rfc for whether this should be dropped silently
|
// FIXME: check rfc for whether this should be dropped silently
|
||||||
continue; // packet failed to auth or decrypt, drop it and try the next one
|
continue; // packet failed to auth or decrypt, drop it and try the next one
|
||||||
|
@ -347,21 +343,68 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, D> Write for DtlsSrtp<S, D>
|
impl<S, D> AsyncWrite for DtlsSrtp<S, D>
|
||||||
where
|
where
|
||||||
S: Read + Write,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
||||||
{
|
{
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
if let Some(buf) = self.process_outgoing_srtp_packet(buf) {
|
if let Some(buf) = self.process_outgoing_srtp_packet(buf) {
|
||||||
self.stream.write(&buf)
|
Pin::new(&mut self.stream).poll_write(cx, &buf)
|
||||||
} else {
|
} else {
|
||||||
Ok(buf.len())
|
Poll::Ready(Ok(buf.len()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
self.stream.flush()
|
Pin::new(&mut self.stream).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.stream).poll_close(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, D> Sink<&[u8]> for DtlsSrtp<S, D>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
||||||
|
{
|
||||||
|
type Error = io::Error;
|
||||||
|
|
||||||
|
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
|
let _ = Sink::poll_flush(self.as_mut(), cx)?;
|
||||||
|
if self.sink_buf.is_none() {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(mut self: Pin<&mut Self>, item: &[u8]) -> io::Result<()> {
|
||||||
|
self.sink_buf = self.process_outgoing_srtp_packet(item.as_ref());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
|
if let Some(buf) = self.sink_buf.take() {
|
||||||
|
match Pin::new(&mut self.stream).poll_write(cx, &buf) {
|
||||||
|
Poll::Pending => {
|
||||||
|
self.sink_buf = Some(buf);
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Pin::new(&mut self.stream).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.stream).poll_close(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -373,6 +416,28 @@ pub(crate) mod test {
|
||||||
TEST_SRTP_SSRC,
|
TEST_SRTP_SSRC,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use futures::{AsyncReadExt, AsyncWriteExt, FutureExt};
|
||||||
|
|
||||||
|
macro_rules! read_now {
|
||||||
|
( $expr:expr, $buf:expr ) => {
|
||||||
|
$expr
|
||||||
|
.read($buf)
|
||||||
|
.now_or_never()
|
||||||
|
.expect("would block")
|
||||||
|
.expect("reading")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! write_now {
|
||||||
|
( $expr:expr, $buf:expr ) => {
|
||||||
|
$expr
|
||||||
|
.write($buf)
|
||||||
|
.now_or_never()
|
||||||
|
.expect("would block")
|
||||||
|
.expect("writing")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
struct DummyDtlsBuilder;
|
struct DummyDtlsBuilder;
|
||||||
struct DummyDtls<S> {
|
struct DummyDtls<S> {
|
||||||
connected: bool,
|
connected: bool,
|
||||||
|
@ -388,38 +453,31 @@ pub(crate) mod test {
|
||||||
DummyDtlsBuilder {}
|
DummyDtlsBuilder {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl<S: Read + Write> DtlsBuilder<S> for DummyDtlsBuilder {
|
#[async_trait]
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin + Send> DtlsBuilder<S> for DummyDtlsBuilder {
|
||||||
type Instance = DummyDtls<S>;
|
type Instance = DummyDtls<S>;
|
||||||
type MidHandshake = DummyDtls<S>;
|
|
||||||
|
|
||||||
fn handshake(
|
async fn handshake(self, mut stream: S) -> Result<Self::Instance, io::Error>
|
||||||
self,
|
where
|
||||||
mut stream: S,
|
S: 'async_trait,
|
||||||
) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake> {
|
{
|
||||||
stream.write(DUMMY_DTLS_HELLO).unwrap();
|
let _ = stream.write(DUMMY_DTLS_HELLO).await;
|
||||||
DummyDtls {
|
let mut dtls = DummyDtls {
|
||||||
stream,
|
stream,
|
||||||
connected: false,
|
connected: false,
|
||||||
}
|
};
|
||||||
.handshake()
|
loop {
|
||||||
}
|
let _ = futures::poll!(dtls.read(&mut []));
|
||||||
}
|
if dtls.connected {
|
||||||
impl<S: Read + Write> DtlsMidHandshake<S> for DummyDtls<S> {
|
break;
|
||||||
type Instance = Self;
|
|
||||||
fn handshake(mut self) -> DtlsHandshakeResult<Self::Instance, Self> {
|
|
||||||
let result = self.read(&mut []).unwrap_err();
|
|
||||||
if result.kind() == io::ErrorKind::WouldBlock {
|
|
||||||
if self.connected {
|
|
||||||
DtlsHandshakeResult::Success(self)
|
|
||||||
} else {
|
} else {
|
||||||
DtlsHandshakeResult::WouldBlock(self)
|
futures::pending!();
|
||||||
}
|
|
||||||
} else {
|
|
||||||
DtlsHandshakeResult::Failure(result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Ok(dtls)
|
||||||
}
|
}
|
||||||
impl<S: Read + Write> Dtls<S> for DummyDtls<S> {
|
}
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Unpin> Dtls<S> for DummyDtls<S> {
|
||||||
fn is_server_side(&self) -> bool {
|
fn is_server_side(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
@ -440,17 +498,21 @@ pub(crate) mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: Read + Write> Read for DummyDtls<S> {
|
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for DummyDtls<S> {
|
||||||
fn read(&mut self, _dst: &mut [u8]) -> io::Result<usize> {
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context,
|
||||||
|
_dst: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
loop {
|
loop {
|
||||||
let mut buf = [0u8; 2048];
|
let mut buf = [0u8; 2048];
|
||||||
let len = self.stream.read(&mut buf)?;
|
let len = ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?;
|
||||||
assert_eq!(len, 2);
|
assert_eq!(len, 2);
|
||||||
assert_eq!(buf[1], 42);
|
assert_eq!(buf[1], 42);
|
||||||
match &buf[..len] {
|
match &buf[..len] {
|
||||||
DUMMY_DTLS_NOOP => {}
|
DUMMY_DTLS_NOOP => {}
|
||||||
DUMMY_DTLS_HELLO => {
|
DUMMY_DTLS_HELLO => {
|
||||||
self.stream.write(DUMMY_DTLS_CONNECTED)?;
|
let _ = Pin::new(&mut self.stream).poll_write(cx, DUMMY_DTLS_CONNECTED)?;
|
||||||
}
|
}
|
||||||
DUMMY_DTLS_CONNECTED => {
|
DUMMY_DTLS_CONNECTED => {
|
||||||
self.connected = true;
|
self.connected = true;
|
||||||
|
@ -461,13 +523,21 @@ pub(crate) mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: Write> Write for DummyDtls<S> {
|
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for DummyDtls<S> {
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn poll_write(
|
||||||
self.stream.write(buf)
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.stream).poll_write(cx, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
self.stream.flush()
|
Pin::new(&mut self.stream).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.stream).poll_flush(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -498,38 +568,45 @@ pub(crate) mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Read for DummyTransport {
|
impl AsyncRead for DummyTransport {
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
match self.read_buf.lock().unwrap().pop_front() {
|
match self.read_buf.lock().unwrap().pop_front() {
|
||||||
None => Err(io::Error::new(io::ErrorKind::WouldBlock, "")),
|
None => Poll::Pending,
|
||||||
Some(elem) => (&mut &elem[..]).read(buf),
|
Some(elem) => Poll::Ready(std::io::Read::read(&mut &elem[..], buf)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Write for DummyTransport {
|
impl AsyncWrite for DummyTransport {
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
self.write_buf.lock().unwrap().push_back(buf.to_vec());
|
self.write_buf.lock().unwrap().push_back(buf.to_vec());
|
||||||
Ok(buf.len())
|
Poll::Ready(Ok(buf.len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
Ok(())
|
Poll::Ready(Ok(()))
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! assert_wouldblock {
|
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
( $expr:expr ) => {
|
Poll::Ready(Ok(()))
|
||||||
let err = $expr.unwrap_err();
|
}
|
||||||
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_dtls_srtp() -> (DummyTransport, DtlsSrtp<DummyTransport, DummyDtlsBuilder>) {
|
fn new_dtls_srtp() -> (DummyTransport, DtlsSrtp<DummyTransport, DummyDtlsBuilder>) {
|
||||||
let (mut stream, dummy_stream) = DummyTransport::new();
|
let (mut stream, dummy_stream) = DummyTransport::new();
|
||||||
stream.write(DUMMY_DTLS_CONNECTED).unwrap();
|
write_now!(stream, DUMMY_DTLS_CONNECTED);
|
||||||
match DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()) {
|
let mut dtls_srtp = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new())
|
||||||
DtlsSrtpHandshakeResult::Success(mut dtls_srtp) => {
|
.now_or_never()
|
||||||
|
.expect("DTLS-SRTP handshake did not complete")
|
||||||
|
.expect("DTL-SRTP handshake failed");
|
||||||
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
|
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
|
||||||
dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC);
|
dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC);
|
||||||
dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC);
|
dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC);
|
||||||
|
@ -537,37 +614,36 @@ pub(crate) mod test {
|
||||||
dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC);
|
dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC);
|
||||||
(stream, dtls_srtp)
|
(stream, dtls_srtp)
|
||||||
}
|
}
|
||||||
_ => panic!("DTLS-SRTP handshake failed"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn polls_dtls_layer_for_keys() {
|
fn polls_dtls_layer_for_keys() {
|
||||||
|
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
|
||||||
let (mut stream, dummy_stream) = DummyTransport::new();
|
let (mut stream, dummy_stream) = DummyTransport::new();
|
||||||
let handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new());
|
let mut handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()).boxed();
|
||||||
let handshake = match handshake {
|
match handshake.as_mut().poll(&mut cx) {
|
||||||
DtlsSrtpHandshakeResult::WouldBlock(it) => it,
|
Poll::Pending => {}
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
stream.write(TEST_SRTP_PACKET).unwrap(); // too early, should be discarded
|
// too early, should be discarded
|
||||||
|
write_now!(stream, TEST_SRTP_PACKET);
|
||||||
|
|
||||||
let handshake = match handshake.handshake() {
|
match handshake.as_mut().poll(&mut cx) {
|
||||||
DtlsSrtpHandshakeResult::WouldBlock(it) => it,
|
Poll::Pending => {}
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
};
|
};
|
||||||
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
|
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO);
|
||||||
|
|
||||||
stream.write(DUMMY_DTLS_HELLO).unwrap();
|
write_now!(stream, DUMMY_DTLS_HELLO);
|
||||||
let handshake = match handshake.handshake() {
|
match handshake.as_mut().poll(&mut cx) {
|
||||||
DtlsSrtpHandshakeResult::WouldBlock(it) => it,
|
Poll::Pending => {}
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
};
|
};
|
||||||
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_CONNECTED);
|
assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_CONNECTED);
|
||||||
|
|
||||||
stream.write(DUMMY_DTLS_CONNECTED).unwrap();
|
write_now!(stream, DUMMY_DTLS_CONNECTED);
|
||||||
match handshake.handshake() {
|
match handshake.as_mut().poll(&mut cx) {
|
||||||
DtlsSrtpHandshakeResult::Success(_) => {}
|
Poll::Ready(_) => {}
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -577,10 +653,10 @@ pub(crate) mod test {
|
||||||
let mut buf = [0u8; 2048];
|
let mut buf = [0u8; 2048];
|
||||||
let (mut stream, mut dtls_srtp) = new_dtls_srtp();
|
let (mut stream, mut dtls_srtp) = new_dtls_srtp();
|
||||||
|
|
||||||
stream.write(TEST_SRTP_PACKET).unwrap();
|
write_now!(stream, TEST_SRTP_PACKET);
|
||||||
stream.write(TEST_SRTCP_PACKET).unwrap();
|
write_now!(stream, TEST_SRTCP_PACKET);
|
||||||
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182); // srtp
|
assert_eq!(read_now!(dtls_srtp, &mut buf), 182); // srtp
|
||||||
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68); // srtcp
|
assert_eq!(read_now!(dtls_srtp, &mut buf), 68); // srtcp
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -588,16 +664,16 @@ pub(crate) mod test {
|
||||||
let mut buf = [0u8; 2048];
|
let mut buf = [0u8; 2048];
|
||||||
let (mut stream, mut dtls_srtp) = new_dtls_srtp();
|
let (mut stream, mut dtls_srtp) = new_dtls_srtp();
|
||||||
|
|
||||||
stream.write(TEST_SRTP_PACKET).unwrap();
|
write_now!(stream, TEST_SRTP_PACKET);
|
||||||
stream.write(TEST_SRTP_PACKET).unwrap();
|
write_now!(stream, TEST_SRTP_PACKET);
|
||||||
stream.write(TEST_SRTP_PACKET).unwrap();
|
write_now!(stream, TEST_SRTP_PACKET);
|
||||||
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182);
|
assert_eq!(read_now!(dtls_srtp, &mut buf), 182);
|
||||||
assert_wouldblock!(dtls_srtp.read(&mut buf));
|
assert!(dtls_srtp.read(&mut buf).now_or_never().is_none(),);
|
||||||
|
|
||||||
stream.write(TEST_SRTCP_PACKET).unwrap();
|
write_now!(stream, TEST_SRTCP_PACKET);
|
||||||
stream.write(TEST_SRTCP_PACKET).unwrap();
|
write_now!(stream, TEST_SRTCP_PACKET);
|
||||||
stream.write(TEST_SRTCP_PACKET).unwrap();
|
write_now!(stream, TEST_SRTCP_PACKET);
|
||||||
assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68);
|
assert_eq!(read_now!(dtls_srtp, &mut buf), 68);
|
||||||
assert_wouldblock!(dtls_srtp.read(&mut buf));
|
assert!(dtls_srtp.read(&mut buf).now_or_never().is_none(),);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,19 +1,22 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::io::{AsyncRead, AsyncWrite};
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::{Read, Write};
|
use tokio_openssl::SslStream;
|
||||||
|
use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, Tokio02AsyncReadCompatExt};
|
||||||
|
|
||||||
use openssl::ssl::{
|
use openssl::ssl::{ConnectConfiguration, SslAcceptorBuilder};
|
||||||
HandshakeError, MidHandshakeSslStream, SslAcceptorBuilder, SslConnectorBuilder, SslStream,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::rfc5764::{Dtls, DtlsBuilder, DtlsHandshakeResult, DtlsMidHandshake, SrtpProtectionProfile};
|
use crate::rfc5764::{Dtls, DtlsBuilder, SrtpProtectionProfile};
|
||||||
|
|
||||||
impl<S: Read + Write + Sync> DtlsBuilder<S> for SslConnectorBuilder {
|
type CompatSslStream<S> = Compat<SslStream<Compat<S>>>;
|
||||||
type Instance = SslStream<S>;
|
|
||||||
type MidHandshake = MidHandshakeSslStream<S>;
|
|
||||||
|
|
||||||
fn handshake(mut self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake>
|
#[async_trait]
|
||||||
|
impl<S: AsyncRead + AsyncWrite + Send + Unpin> DtlsBuilder<S> for ConnectConfiguration {
|
||||||
|
type Instance = CompatSslStream<S>;
|
||||||
|
|
||||||
|
async fn handshake(mut self, stream: S) -> Result<Self::Instance, io::Error>
|
||||||
where
|
where
|
||||||
S: Read + Write,
|
S: 'async_trait,
|
||||||
{
|
{
|
||||||
let profiles_str: String = SrtpProtectionProfile::RECOMMENDED
|
let profiles_str: String = SrtpProtectionProfile::RECOMMENDED
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -21,28 +24,20 @@ impl<S: Read + Write + Sync> DtlsBuilder<S> for SslConnectorBuilder {
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(":");
|
.join(":");
|
||||||
self.set_tlsext_use_srtp(&profiles_str).unwrap();
|
self.set_tlsext_use_srtp(&profiles_str).unwrap();
|
||||||
match self.build().connect("invalid", stream) {
|
match tokio_openssl::connect(self, "invalid", stream.compat()).await {
|
||||||
Ok(stream) => DtlsHandshakeResult::Success(stream),
|
Ok(stream) => Ok(stream.compat()),
|
||||||
Err(HandshakeError::WouldBlock(mid_handshake)) => {
|
Err(_) => Err(io::Error::new(io::ErrorKind::Other, "handshake error")),
|
||||||
DtlsHandshakeResult::WouldBlock(mid_handshake)
|
|
||||||
}
|
|
||||||
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
|
|
||||||
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
|
|
||||||
),
|
|
||||||
Err(HandshakeError::SetupFailure(err)) => {
|
|
||||||
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: Read + Write + Sync> DtlsBuilder<S> for SslAcceptorBuilder {
|
#[async_trait]
|
||||||
type Instance = SslStream<S>;
|
impl<S: AsyncRead + AsyncWrite + Send + Unpin> DtlsBuilder<S> for SslAcceptorBuilder {
|
||||||
type MidHandshake = MidHandshakeSslStream<S>;
|
type Instance = CompatSslStream<S>;
|
||||||
|
|
||||||
fn handshake(mut self, stream: S) -> DtlsHandshakeResult<Self::Instance, Self::MidHandshake>
|
async fn handshake(mut self, stream: S) -> Result<Self::Instance, io::Error>
|
||||||
where
|
where
|
||||||
S: Read + Write,
|
S: 'async_trait,
|
||||||
{
|
{
|
||||||
let profiles_str: String = SrtpProtectionProfile::RECOMMENDED
|
let profiles_str: String = SrtpProtectionProfile::RECOMMENDED
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -50,48 +45,22 @@ impl<S: Read + Write + Sync> DtlsBuilder<S> for SslAcceptorBuilder {
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(":");
|
.join(":");
|
||||||
self.set_tlsext_use_srtp(&profiles_str).unwrap();
|
self.set_tlsext_use_srtp(&profiles_str).unwrap();
|
||||||
match self.build().accept(stream) {
|
match tokio_openssl::accept(&self.build(), stream.compat()).await {
|
||||||
Ok(stream) => DtlsHandshakeResult::Success(stream),
|
Ok(stream) => Ok(stream.compat()),
|
||||||
Err(HandshakeError::WouldBlock(mid_handshake)) => {
|
Err(_) => Err(io::Error::new(io::ErrorKind::Other, "handshake error")),
|
||||||
DtlsHandshakeResult::WouldBlock(mid_handshake)
|
|
||||||
}
|
|
||||||
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
|
|
||||||
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
|
|
||||||
),
|
|
||||||
Err(HandshakeError::SetupFailure(err)) => {
|
|
||||||
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: Read + Write + Sync> DtlsMidHandshake<S> for MidHandshakeSslStream<S> {
|
impl<S: AsyncRead + AsyncWrite + Unpin> Dtls<S> for Compat<SslStream<Compat<S>>> {
|
||||||
type Instance = SslStream<S>;
|
|
||||||
|
|
||||||
fn handshake(self) -> DtlsHandshakeResult<Self::Instance, Self> {
|
|
||||||
match MidHandshakeSslStream::handshake(self) {
|
|
||||||
Ok(stream) => DtlsHandshakeResult::Success(stream),
|
|
||||||
Err(HandshakeError::WouldBlock(mid_handshake)) => {
|
|
||||||
DtlsHandshakeResult::WouldBlock(mid_handshake)
|
|
||||||
}
|
|
||||||
Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure(
|
|
||||||
io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()),
|
|
||||||
),
|
|
||||||
Err(HandshakeError::SetupFailure(err)) => {
|
|
||||||
DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: Read + Write> Dtls<S> for SslStream<S> {
|
|
||||||
fn is_server_side(&self) -> bool {
|
fn is_server_side(&self) -> bool {
|
||||||
self.ssl().is_server()
|
self.get_ref().ssl().is_server()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8> {
|
fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec<u8> {
|
||||||
let mut vec = vec![0; length];
|
let mut vec = vec![0; length];
|
||||||
self.ssl()
|
self.get_mut()
|
||||||
|
.ssl()
|
||||||
.export_keying_material(&mut vec, exporter_label, None)
|
.export_keying_material(&mut vec, exporter_label, None)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
vec
|
vec
|
||||||
|
@ -101,14 +70,16 @@ impl<S: Read + Write> Dtls<S> for SslStream<S> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use crate::rfc5764::test::DummyTransport;
|
use crate::rfc5764::test::DummyTransport;
|
||||||
use crate::rfc5764::{DtlsSrtp, DtlsSrtpHandshakeResult};
|
use crate::rfc5764::DtlsSrtp;
|
||||||
|
|
||||||
|
use futures::FutureExt;
|
||||||
use openssl::asn1::Asn1Time;
|
use openssl::asn1::Asn1Time;
|
||||||
use openssl::hash::MessageDigest;
|
use openssl::hash::MessageDigest;
|
||||||
use openssl::pkey::PKey;
|
use openssl::pkey::PKey;
|
||||||
use openssl::rsa::Rsa;
|
use openssl::rsa::Rsa;
|
||||||
use openssl::ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode};
|
use openssl::ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode};
|
||||||
use openssl::x509::X509;
|
use openssl::x509::X509;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn connect_and_establish_matching_key_material() {
|
fn connect_and_establish_matching_key_material() {
|
||||||
|
@ -134,21 +105,16 @@ mod test {
|
||||||
acceptor.set_private_key(&key).unwrap();
|
acceptor.set_private_key(&key).unwrap();
|
||||||
acceptor.set_verify(SslVerifyMode::NONE);
|
acceptor.set_verify(SslVerifyMode::NONE);
|
||||||
connector.set_verify(SslVerifyMode::NONE);
|
connector.set_verify(SslVerifyMode::NONE);
|
||||||
let mut handshake_server = DtlsSrtp::handshake(server_sock, acceptor);
|
let handshake_server = DtlsSrtp::handshake(server_sock, acceptor);
|
||||||
let mut handshake_client = DtlsSrtp::handshake(client_sock, connector);
|
let handshake_client =
|
||||||
|
DtlsSrtp::handshake(client_sock, connector.build().configure().unwrap());
|
||||||
|
let mut future = futures::future::join(handshake_server, handshake_client).boxed();
|
||||||
|
|
||||||
|
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
|
||||||
loop {
|
loop {
|
||||||
match (handshake_client, handshake_server) {
|
if let Poll::Ready((server, client)) = future.as_mut().poll(&mut cx) {
|
||||||
(DtlsSrtpHandshakeResult::Failure(err), _) => {
|
let server = server.unwrap();
|
||||||
panic!("Client error: {}", err);
|
let client = client.unwrap();
|
||||||
}
|
|
||||||
(_, DtlsSrtpHandshakeResult::Failure(err)) => {
|
|
||||||
panic!("Server error: {}", err);
|
|
||||||
}
|
|
||||||
(
|
|
||||||
DtlsSrtpHandshakeResult::Success(client),
|
|
||||||
DtlsSrtpHandshakeResult::Success(server),
|
|
||||||
) => {
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
client.srtp_read_context.master_key,
|
client.srtp_read_context.master_key,
|
||||||
server.srtp_write_context.master_key
|
server.srtp_write_context.master_key
|
||||||
|
@ -167,22 +133,6 @@ mod test {
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
(
|
|
||||||
DtlsSrtpHandshakeResult::WouldBlock(client),
|
|
||||||
DtlsSrtpHandshakeResult::WouldBlock(server),
|
|
||||||
) => {
|
|
||||||
handshake_client = client.handshake();
|
|
||||||
handshake_server = server.handshake();
|
|
||||||
}
|
|
||||||
(client, DtlsSrtpHandshakeResult::WouldBlock(server)) => {
|
|
||||||
handshake_client = client;
|
|
||||||
handshake_server = server.handshake();
|
|
||||||
}
|
|
||||||
(DtlsSrtpHandshakeResult::WouldBlock(client), server) => {
|
|
||||||
handshake_client = client.handshake();
|
|
||||||
handshake_server = server;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,95 +0,0 @@
|
||||||
use std::io;
|
|
||||||
use tokio::prelude::{Async, AsyncRead, AsyncSink, AsyncWrite, Future, Sink, Stream};
|
|
||||||
|
|
||||||
use crate::rfc5764::{DtlsBuilder, DtlsSrtp, DtlsSrtpHandshakeResult, DtlsSrtpMuxerPart};
|
|
||||||
|
|
||||||
impl<S, D> AsyncRead for DtlsSrtp<S, D>
|
|
||||||
where
|
|
||||||
S: AsyncRead + AsyncWrite,
|
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, D> AsyncWrite for DtlsSrtp<S, D>
|
|
||||||
where
|
|
||||||
S: AsyncRead + AsyncWrite,
|
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
|
||||||
{
|
|
||||||
fn shutdown(&mut self) -> io::Result<Async<()>> {
|
|
||||||
Ok(().into()) // FIXME
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, D> Future for DtlsSrtpHandshakeResult<S, D>
|
|
||||||
where
|
|
||||||
S: AsyncRead + AsyncWrite,
|
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
|
||||||
{
|
|
||||||
type Item = DtlsSrtp<S, D>;
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll(&mut self) -> io::Result<Async<Self::Item>> {
|
|
||||||
let mut owned = DtlsSrtpHandshakeResult::Failure(io::Error::new(
|
|
||||||
io::ErrorKind::Other,
|
|
||||||
"poll called after completion",
|
|
||||||
));
|
|
||||||
std::mem::swap(&mut owned, self);
|
|
||||||
match owned {
|
|
||||||
DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)),
|
|
||||||
DtlsSrtpHandshakeResult::WouldBlock(mid_handshake) => match mid_handshake.handshake() {
|
|
||||||
DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)),
|
|
||||||
mut new @ DtlsSrtpHandshakeResult::WouldBlock(_) => {
|
|
||||||
std::mem::swap(&mut new, self);
|
|
||||||
Ok(Async::NotReady)
|
|
||||||
}
|
|
||||||
DtlsSrtpHandshakeResult::Failure(err) => Err(err),
|
|
||||||
},
|
|
||||||
DtlsSrtpHandshakeResult::Failure(err) => Err(err),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, D> Stream for DtlsSrtp<S, D>
|
|
||||||
where
|
|
||||||
S: AsyncRead + AsyncWrite,
|
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
|
||||||
DtlsSrtp<S, D>: AsyncRead + AsyncWrite,
|
|
||||||
{
|
|
||||||
type Item = Vec<u8>;
|
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll(&mut self) -> io::Result<Async<Option<Self::Item>>> {
|
|
||||||
let mut buf = [0; 2048];
|
|
||||||
Ok(match self.poll_read(&mut buf)? {
|
|
||||||
Async::Ready(len) => {
|
|
||||||
if len == 0 {
|
|
||||||
Async::Ready(None)
|
|
||||||
} else {
|
|
||||||
Async::Ready(Some(buf[..len].to_vec()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Async::NotReady => Async::NotReady,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, D> Sink for DtlsSrtp<S, D>
|
|
||||||
where
|
|
||||||
S: AsyncRead + AsyncWrite,
|
|
||||||
D: DtlsBuilder<DtlsSrtpMuxerPart<S>>,
|
|
||||||
DtlsSrtp<S, D>: AsyncRead + AsyncWrite,
|
|
||||||
{
|
|
||||||
type SinkItem = Vec<u8>;
|
|
||||||
type SinkError = io::Error;
|
|
||||||
|
|
||||||
fn start_send(&mut self, buf: Self::SinkItem) -> io::Result<AsyncSink<Self::SinkItem>> {
|
|
||||||
Ok(match self.poll_write(&buf[..])? {
|
|
||||||
Async::Ready(_) => AsyncSink::Ready,
|
|
||||||
Async::NotReady => AsyncSink::NotReady(buf),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_complete(&mut self) -> io::Result<Async<()>> {
|
|
||||||
self.poll_flush()
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue