From 0a8ec94f7dbcf91dfd92f02aed868eb839a6a3c1 Mon Sep 17 00:00:00 2001 From: Jonas Herzig Date: Sun, 2 Dec 2018 14:15:34 +0100 Subject: [PATCH] Add DTLS-SRTP support (rfc5764) https://tools.ietf.org/html/rfc5764 --- Cargo.toml | 3 + README.md | 1 + src/lib.rs | 6 + src/rfc3711.rs | 15 +- src/rfc5764/mod.rs | 603 +++++++++++++++++++++++++++++++++++++++++ src/rfc5764/openssl.rs | 188 +++++++++++++ src/rfc5764/tokio.rs | 95 +++++++ 7 files changed, 904 insertions(+), 7 deletions(-) create mode 100644 src/rfc5764/mod.rs create mode 100644 src/rfc5764/openssl.rs create mode 100644 src/rfc5764/tokio.rs diff --git a/Cargo.toml b/Cargo.toml index 92d0174..cb9a7c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,9 @@ rust-crypto = "0.2" num = "0.1" fixedbitset = "0.1" +openssl = { version = "0.10", optional = true } +tokio = { version = "0.1", optional = true } + [dev-dependencies] clap = "2" fibers = "0.1" diff --git a/README.md b/README.md index e8f7802..5e56bb7 100644 --- a/README.md +++ b/README.md @@ -12,3 +12,4 @@ RFC - AVPF: https://tools.ietf.org/html/rfc4585 - SAVPF: https://tools.ietf.org/html/rfc5124 - Multiplexing RTP and RTCP: https://tools.ietf.org/html/rfc5761 +- DTLS-SRTP: https://tools.ietf.org/html/rfc5764 diff --git a/src/lib.rs b/src/lib.rs index a5cae5b..af1f26b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,11 @@ extern crate handy_async; extern crate num; extern crate fixedbitset; +#[cfg(feature = "openssl")] +extern crate openssl; +#[cfg(feature = "tokio")] +extern crate tokio; + pub use error::{Error, ErrorKind}; pub mod io; @@ -12,6 +17,7 @@ pub mod rfc3550; pub mod rfc3711; pub mod rfc4585; pub mod rfc5761; +pub mod rfc5764; pub mod traits; mod error; diff --git a/src/rfc3711.rs b/src/rfc3711.rs index b9969f5..77ff66a 100644 --- a/src/rfc3711.rs +++ b/src/rfc3711.rs @@ -1,3 +1,4 @@ +// FIXME: saveguard against two-time pad by running replay-protection on outgoing packets use crypto; use fixedbitset::FixedBitSet; use handy_async::sync_io::{ReadExt, WriteExt}; @@ -829,7 +830,7 @@ fn prf_n(master_key: &[u8], x: BigUint, n: usize) -> Vec { } #[cfg(test)] -mod test { +pub(crate) mod test { use super::*; use rfc3550; use rfc4585; @@ -855,14 +856,14 @@ mod test { assert_eq!(estimate(&context, 10001), i(roc_p1, 10001)); // roc+1 } - const TEST_MASTER_KEY: &[u8] = &[ + pub(crate) const TEST_MASTER_KEY: &[u8] = &[ 211, 77, 116, 243, 125, 116, 231, 95, 59, 219, 79, 118, 241, 189, 244, 119, ]; - const TEST_MASTER_SALT: &[u8] = &[ + pub(crate) const TEST_MASTER_SALT: &[u8] = &[ 127, 31, 227, 93, 120, 247, 126, 117, 231, 159, 123, 235, 95, 122, ]; - const TEST_SRTP_SSRC: Ssrc = 446919554; - const TEST_SRTP_PACKET: &[u8] = &[ + pub(crate) const TEST_SRTP_SSRC: Ssrc = 446919554; + pub(crate) const TEST_SRTP_PACKET: &[u8] = &[ 128, 0, 3, 92, 222, 161, 6, 76, 26, 163, 115, 130, 222, 0, 143, 87, 0, 227, 123, 91, 200, 238, 141, 220, 9, 191, 52, 111, 100, 62, 220, 158, 211, 79, 184, 199, 79, 182, 9, 248, 170, 82, 125, 152, 143, 206, 8, 152, 80, 207, 27, 183, 141, 77, 33, 60, 101, 180, 210, 146, 139, @@ -874,8 +875,8 @@ mod test { 7, 52, 191, 129, 239, 86, 78, 172, 229, 178, 112, 22, 125, 191, 164, 17, 193, 24, 152, 197, 146, 94, 74, 156, 171, 245, 239, 220, 205, 145, 206, ]; - const TEST_SRTCP_SSRC: Ssrc = 3270675037; - const TEST_SRTCP_PACKET: &[u8] = &[ + pub(crate) const TEST_SRTCP_SSRC: Ssrc = 3270675037; + pub(crate) const TEST_SRTCP_PACKET: &[u8] = &[ 128, 201, 0, 1, 194, 242, 138, 93, 177, 31, 99, 88, 187, 209, 173, 181, 135, 18, 79, 59, 119, 153, 115, 34, 75, 94, 96, 29, 32, 14, 118, 86, 145, 159, 203, 174, 225, 34, 196, 229, 39, 22, 174, 54, 198, 56, 179, 171, 111, 229, 48, 234, 138, 249, 127, 11, 86, 94, 40, 213, diff --git a/src/rfc5764/mod.rs b/src/rfc5764/mod.rs new file mode 100644 index 0000000..666ff36 --- /dev/null +++ b/src/rfc5764/mod.rs @@ -0,0 +1,603 @@ +// FIXME: the current SRTP implementation does not support the maximum_lifetime parameter + +#[cfg(feature = "openssl")] +mod openssl; + +#[cfg(feature = "tokio")] +mod tokio; + +use std::collections::VecDeque; +use std::io; +use std::io::{Read, Write}; +use std::sync::Arc; +use std::sync::Mutex; + +use rfc3711::{AuthenticationAlgorithm, Context, EncryptionAlgorithm, Srtcp, Srtp}; +use types::Ssrc; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SrtpProtectionProfile { + pub name: &'static str, + pub cipher: EncryptionAlgorithm, + pub cipher_key_length: u8, + pub cipher_salt_length: u8, + pub maximum_lifetime: u32, + pub auth_function: AuthenticationAlgorithm, + pub auth_key_length: u8, + pub auth_salt_length: u8, +} + +impl SrtpProtectionProfile { + pub const AES128_CM_HMAC_SHA1_80: SrtpProtectionProfile = SrtpProtectionProfile { + name: "SRTP_AES128_CM_SHA1_80", + cipher: EncryptionAlgorithm::AesCm, + cipher_key_length: 128, + cipher_salt_length: 112, + maximum_lifetime: 2 ^ 31, + auth_function: AuthenticationAlgorithm::HmacSha1, + auth_key_length: 160, + auth_salt_length: 80, + }; + // AES128_CM_HMAC_SHA1_32 is not supported due to recommendation in rfc3711#5.2 + // NULL_HMAC_SHA1_80 is not supported because the NULL cipher isn't implemented + // NULL_HMAC_SHA1_32 is not supported due to recommendation in rfc3711#5.2 (and lack of NULL) + + pub const RECOMMENDED: &'static [&'static SrtpProtectionProfile] = + &[&SrtpProtectionProfile::AES128_CM_HMAC_SHA1_80]; + pub const SUPPORTED: &'static [&'static SrtpProtectionProfile] = + &[&SrtpProtectionProfile::AES128_CM_HMAC_SHA1_80]; +} + +pub enum DtlsHandshakeResult { + Failure(io::Error), + WouldBlock(DtlsMidHandshake), + Success(Dtls), +} + +pub trait DtlsBuilder { + type Instance: Dtls; + type MidHandshake: DtlsMidHandshake; + + fn handshake(self, stream: S) -> DtlsHandshakeResult + where + S: Read + Write; +} + +pub trait DtlsMidHandshake: Sized { + type Instance: Dtls; + + fn handshake(self) -> DtlsHandshakeResult; +} + +pub trait Dtls: Read + Write { + fn is_server_side(&self) -> bool; + fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec; +} + +pub struct DtlsSrtpMuxer { + inner: S, + dtls_buf: VecDeque>, + srtp_buf: VecDeque>, +} + +impl DtlsSrtpMuxer { + fn new(inner: S) -> Self { + DtlsSrtpMuxer { + inner, + dtls_buf: VecDeque::new(), + srtp_buf: VecDeque::new(), + } + } +} + +impl DtlsSrtpMuxer { + fn into_parts(self) -> (DtlsSrtpMuxerPart, DtlsSrtpMuxerPart) { + let muxer = Arc::new(Mutex::new(self)); + let dtls = DtlsSrtpMuxerPart { + muxer: muxer.clone(), + srtp: false, + }; + let srtp = DtlsSrtpMuxerPart { muxer, srtp: true }; + (dtls, srtp) + } +} + +impl DtlsSrtpMuxer { + fn read(&mut self, want_srtp: bool, dst_buf: &mut [u8]) -> io::Result { + { + let want_buf = if want_srtp { + &mut self.srtp_buf + } else { + &mut self.dtls_buf + }; + if let Some(buf) = want_buf.pop_front() { + return (&buf[..]).read(dst_buf); + } + } + + let mut buf = [0u8; 2048]; + let len = self.inner.read(&mut buf)?; + if len == 0 { + return Ok(0); + } + let mut buf = &buf[..len]; + // Demux SRTP and DTLS as per https://tools.ietf.org/html/rfc5764#section-5.1.2 + let is_srtp = buf[0] >= 128 && buf[0] <= 191; + if is_srtp == want_srtp { + buf.read(dst_buf) + } else { + if is_srtp { + &mut self.srtp_buf + } else { + &mut self.dtls_buf + } + .push_back(buf.to_vec()); + // We have to make sure we're not waiting for, e.g., a srtp packet when + // we just got a dtls packet and the remote is waiting on a reply to it. + // So, to prevent this kind of deadlock, we abort the current read-path + // by pretending that we're doing non-blocking io (even if we aren't) + // to get back to where we can enter the other (in the example: the dtls) + // read-path and process the packet we just read. + Err(io::Error::new(io::ErrorKind::WouldBlock, "")) + } + } +} + +pub struct DtlsSrtpMuxerPart { + muxer: Arc>>, + srtp: bool, +} + +impl Read for DtlsSrtpMuxerPart +where + S: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.muxer.lock().unwrap().read(self.srtp, buf) + } +} + +impl Write for DtlsSrtpMuxerPart +where + S: Write, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + self.muxer.lock().unwrap().inner.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.muxer.lock().unwrap().inner.flush() + } +} + +pub enum DtlsSrtpHandshakeResult>> { + Success(DtlsSrtp), + WouldBlock(DtlsSrtpMidHandshake), + Failure(io::Error), +} + +pub struct DtlsSrtpMidHandshake>> { + stream: DtlsSrtpMuxerPart, + dtls: D::MidHandshake, +} + +pub struct DtlsSrtp>> { + stream: DtlsSrtpMuxerPart, + dtls: D::Instance, + srtp_read_context: Context, + srtcp_read_context: Context, + srtp_write_context: Context, + srtcp_write_context: Context, +} + +impl DtlsSrtpMidHandshake +where + S: Read + Write + Sized, + D: DtlsBuilder>, +{ + pub fn handshake(mut self) -> DtlsSrtpHandshakeResult { + match self.dtls.handshake() { + DtlsHandshakeResult::Success(dtls) => { + DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(self.stream, dtls)) + } + DtlsHandshakeResult::WouldBlock(dtls) => { + self.dtls = dtls; + DtlsSrtpHandshakeResult::WouldBlock(self) + } + DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err), + } + } +} + +impl DtlsSrtp +where + S: Read + Write, + D: DtlsBuilder>, +{ + pub fn handshake(stream: S, dtls_builder: D) -> DtlsSrtpHandshakeResult { + let (stream_dtls, stream_srtp) = DtlsSrtpMuxer::new(stream).into_parts(); + match dtls_builder.handshake(stream_dtls) { + DtlsHandshakeResult::Success(dtls) => { + DtlsSrtpHandshakeResult::Success(DtlsSrtp::new(stream_srtp, dtls)) + } + DtlsHandshakeResult::WouldBlock(dtls) => { + DtlsSrtpHandshakeResult::WouldBlock(DtlsSrtpMidHandshake { + stream: stream_srtp, + dtls, + }) + } + DtlsHandshakeResult::Failure(err) => DtlsSrtpHandshakeResult::Failure(err), + } + } + + fn new(stream: DtlsSrtpMuxerPart, mut dtls: D::Instance) -> Self { + const EXPORTER_LABEL: &str = "EXTRACTOR-dtls_srtp"; + const KEY_LEN: usize = 16; + const SALT_LEN: usize = 14; + const EXPORT_LEN: usize = (KEY_LEN + SALT_LEN) * 2; + + let key_material = dtls.export_key(EXPORTER_LABEL, EXPORT_LEN); + let client_material = ( + &(&key_material[0..])[..KEY_LEN], + &(&key_material[KEY_LEN * 2..])[..SALT_LEN], + ); + let server_material = ( + &(&key_material[KEY_LEN..])[..KEY_LEN], + &(&key_material[KEY_LEN * 2 + SALT_LEN..])[..SALT_LEN], + ); + let (read_material, write_material) = if dtls.is_server_side() { + (client_material, server_material) + } else { + (server_material, client_material) + }; + let (read_key, read_salt) = read_material; + let (write_key, write_salt) = write_material; + DtlsSrtp { + stream, + dtls, + srtp_read_context: Context::new(&read_key, &read_salt), + srtcp_read_context: Context::new(&read_key, &read_salt), + srtp_write_context: Context::new(&write_key, &write_salt), + srtcp_write_context: Context::new(&write_key, &write_salt), + } + } + + pub fn add_incoming_ssrc(&mut self, ssrc: Ssrc) { + self.srtp_read_context.add_ssrc(ssrc); + self.srtcp_read_context.add_ssrc(ssrc); + } + + pub fn add_incoming_unknown_ssrcs(&mut self, count: usize) { + self.srtp_read_context.add_unknown_ssrcs(count); + self.srtcp_read_context.add_unknown_ssrcs(count); + } + + pub fn add_outgoing_ssrc(&mut self, ssrc: Ssrc) { + self.srtp_write_context.add_ssrc(ssrc); + self.srtcp_write_context.add_ssrc(ssrc); + } + + pub fn add_outgoing_unknown_ssrcs(&mut self, count: usize) { + self.srtp_write_context.add_unknown_ssrcs(count); + self.srtcp_write_context.add_unknown_ssrcs(count); + } + + fn process_incoming_srtp_packet(&mut self, buf: &[u8]) -> Option> { + // Demux SRTP and SRTCP packets as per https://tools.ietf.org/html/rfc5761#section-4 + let payload_type = buf[1] & 0x7f; + if 64 <= payload_type && payload_type <= 95 { + self.srtcp_read_context.process_incoming(buf).ok() + } else { + self.srtp_read_context.process_incoming(buf).ok() + } + } + + fn process_outgoing_srtp_packet(&mut self, buf: &[u8]) -> Option> { + // Demux SRTP and SRTCP packets as per https://tools.ietf.org/html/rfc5761#section-4 + let payload_type = buf[1] & 0x7f; + if 64 <= payload_type && payload_type <= 95 { + self.srtcp_write_context.process_outgoing(buf).ok() + } else { + self.srtp_write_context.process_outgoing(buf).ok() + } + } +} + +impl Read for DtlsSrtp +where + S: Read + Write, + D: DtlsBuilder>, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + loop { + // Check if we have an SRTP packet in the queue + if self.stream.muxer.lock().unwrap().srtp_buf.is_empty() { + // if we don't, then poll the dtls layer which will read from the + // underlying packet stream and produce either dtls data or fill + // the SRTP packet queue or fail due to WouldBlock + /* FIXME polling dtls eventually errs with a read timeout, for some reason + * it does indeed send repeated "Change Cipher Spec" and "Encrypted Handshake + * Message" as if its expecting a response to those but none is sent by FF? + match self.dtls.read(buf) { + Ok(len) => return Ok(len), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + // Either we're using non-blocking io and there's no more data + // available, or we received an SRTP packet which needs handling + } + err => return err, + }; + */ + } + + // Read and handle the next SRTP packet from the queue + let mut raw_buf = [0u8; 2048]; + let len = self.stream.read(&mut raw_buf)?; + if len == 0 { + return Ok(0); + } + let raw_buf = &raw_buf[..len]; + return match self.process_incoming_srtp_packet(raw_buf) { + Some(result) => (&result[..]).read(buf), + None => { + // FIXME: check rfc for whether this should be dropped silently + continue; // packet failed to auth or decrypt, drop it and try the next one + } + }; + } + } +} + +impl Write for DtlsSrtp +where + S: Read + Write, + D: DtlsBuilder>, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + if let Some(buf) = self.process_outgoing_srtp_packet(buf) { + self.stream.write(&buf) + } else { + Ok(buf.len()) + } + } + + fn flush(&mut self) -> io::Result<()> { + self.stream.flush() + } +} + +#[cfg(test)] +pub(crate) mod test { + use super::*; + use rfc3711::test::{ + TEST_MASTER_KEY, TEST_MASTER_SALT, TEST_SRTCP_PACKET, TEST_SRTCP_SSRC, TEST_SRTP_PACKET, + TEST_SRTP_SSRC, + }; + + struct DummyDtlsBuilder; + struct DummyDtls { + connected: bool, + stream: S, + } + + const DUMMY_DTLS_NOOP: &[u8] = &[20, 42]; + const DUMMY_DTLS_HELLO: &[u8] = &[62, 42]; + const DUMMY_DTLS_CONNECTED: &[u8] = &[63, 42]; + + impl DummyDtlsBuilder { + fn new() -> Self { + DummyDtlsBuilder {} + } + } + impl DtlsBuilder for DummyDtlsBuilder { + type Instance = DummyDtls; + type MidHandshake = DummyDtls; + + fn handshake( + self, + mut stream: S, + ) -> DtlsHandshakeResult { + stream.write(DUMMY_DTLS_HELLO).unwrap(); + DummyDtls { + stream, + connected: false, + } + .handshake() + } + } + impl DtlsMidHandshake for DummyDtls { + type Instance = Self; + fn handshake(mut self) -> DtlsHandshakeResult { + let result = self.read(&mut []).unwrap_err(); + if result.kind() == io::ErrorKind::WouldBlock { + if self.connected { + DtlsHandshakeResult::Success(self) + } else { + DtlsHandshakeResult::WouldBlock(self) + } + } else { + DtlsHandshakeResult::Failure(result) + } + } + } + impl Dtls for DummyDtls { + fn is_server_side(&self) -> bool { + true + } + + fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec { + assert_eq!(exporter_label, "EXTRACTOR-dtls_srtp"); + let mut buf = Vec::new(); + buf.extend(TEST_MASTER_KEY); + buf.extend(TEST_MASTER_KEY); + let idx = buf.len() - 1; + buf[idx] = 0; + buf.extend(TEST_MASTER_SALT); + buf.extend(TEST_MASTER_SALT); + let idx = buf.len() - 1; + buf[idx] = 0; + assert_eq!(length, buf.len()); + buf + } + } + + impl Read for DummyDtls { + fn read(&mut self, _dst: &mut [u8]) -> io::Result { + loop { + let mut buf = [0u8; 2048]; + let len = self.stream.read(&mut buf)?; + assert_eq!(len, 2); + assert_eq!(buf[1], 42); + match &buf[..len] { + DUMMY_DTLS_NOOP => {} + DUMMY_DTLS_HELLO => { + self.stream.write(DUMMY_DTLS_CONNECTED)?; + } + DUMMY_DTLS_CONNECTED => { + self.connected = true; + } + _ => panic!(), + }; + } + } + } + + impl Write for DummyDtls { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.stream.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.stream.flush() + } + } + + type PacketBufArc = Arc>>>; + pub(crate) struct DummyTransport { + read_buf: PacketBufArc, + write_buf: PacketBufArc, + } + + impl DummyTransport { + pub fn new() -> (Self, Self) { + let read_buf = Arc::new(Mutex::new(VecDeque::new())); + let write_buf = Arc::new(Mutex::new(VecDeque::new())); + ( + DummyTransport { + read_buf: read_buf.clone(), + write_buf: write_buf.clone(), + }, + DummyTransport { + read_buf: write_buf.clone(), + write_buf: read_buf.clone(), + }, + ) + } + + pub fn read_packet(&mut self) -> Option> { + self.read_buf.lock().unwrap().pop_front() + } + } + + impl Read for DummyTransport { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.read_buf.lock().unwrap().pop_front() { + None => Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Some(elem) => (&mut &elem[..]).read(buf), + } + } + } + + impl Write for DummyTransport { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.write_buf.lock().unwrap().push_back(buf.to_vec()); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + + macro_rules! assert_wouldblock { + ( $expr:expr ) => { + let err = $expr.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + }; + } + + fn new_dtls_srtp() -> (DummyTransport, DtlsSrtp) { + let (mut stream, dummy_stream) = DummyTransport::new(); + stream.write(DUMMY_DTLS_CONNECTED).unwrap(); + match DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()) { + DtlsSrtpHandshakeResult::Success(mut dtls_srtp) => { + assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO); + dtls_srtp.add_incoming_ssrc(TEST_SRTP_SSRC); + dtls_srtp.add_incoming_ssrc(TEST_SRTCP_SSRC); + dtls_srtp.add_outgoing_ssrc(TEST_SRTP_SSRC); + dtls_srtp.add_outgoing_ssrc(TEST_SRTCP_SSRC); + (stream, dtls_srtp) + } + _ => panic!("DTLS-SRTP handshake failed"), + } + } + + #[test] + fn polls_dtls_layer_for_keys() { + let (mut stream, dummy_stream) = DummyTransport::new(); + let handshake = DtlsSrtp::handshake(dummy_stream, DummyDtlsBuilder::new()); + let handshake = match handshake { + DtlsSrtpHandshakeResult::WouldBlock(it) => it, + _ => panic!(), + }; + + stream.write(TEST_SRTP_PACKET).unwrap(); // too early, should be discarded + + let handshake = match handshake.handshake() { + DtlsSrtpHandshakeResult::WouldBlock(it) => it, + _ => panic!(), + }; + assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_HELLO); + + stream.write(DUMMY_DTLS_HELLO).unwrap(); + let handshake = match handshake.handshake() { + DtlsSrtpHandshakeResult::WouldBlock(it) => it, + _ => panic!(), + }; + assert_eq!(&stream.read_packet().unwrap()[..], DUMMY_DTLS_CONNECTED); + + stream.write(DUMMY_DTLS_CONNECTED).unwrap(); + match handshake.handshake() { + DtlsSrtpHandshakeResult::Success(_) => {} + _ => panic!(), + }; + } + + #[test] + fn decryption_of_incoming_srtp_and_srtcp_works() { + let mut buf = [0u8; 2048]; + let (mut stream, mut dtls_srtp) = new_dtls_srtp(); + + stream.write(TEST_SRTP_PACKET).unwrap(); + stream.write(TEST_SRTCP_PACKET).unwrap(); + assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182); // srtp + assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68); // srtcp + } + + #[test] + fn does_not_allow_replay_of_packets() { + let mut buf = [0u8; 2048]; + let (mut stream, mut dtls_srtp) = new_dtls_srtp(); + + stream.write(TEST_SRTP_PACKET).unwrap(); + stream.write(TEST_SRTP_PACKET).unwrap(); + stream.write(TEST_SRTP_PACKET).unwrap(); + assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 182); + assert_wouldblock!(dtls_srtp.read(&mut buf)); + + stream.write(TEST_SRTCP_PACKET).unwrap(); + stream.write(TEST_SRTCP_PACKET).unwrap(); + stream.write(TEST_SRTCP_PACKET).unwrap(); + assert_eq!(dtls_srtp.read(&mut buf).unwrap(), 68); + assert_wouldblock!(dtls_srtp.read(&mut buf)); + } +} diff --git a/src/rfc5764/openssl.rs b/src/rfc5764/openssl.rs new file mode 100644 index 0000000..8e4a760 --- /dev/null +++ b/src/rfc5764/openssl.rs @@ -0,0 +1,188 @@ +use std::io; +use std::io::{Read, Write}; + +use openssl::ssl::{ + HandshakeError, MidHandshakeSslStream, SslAcceptorBuilder, SslConnectorBuilder, SslStream, +}; + +use rfc5764::{Dtls, DtlsBuilder, DtlsHandshakeResult, DtlsMidHandshake, SrtpProtectionProfile}; + +impl DtlsBuilder for SslConnectorBuilder { + type Instance = SslStream; + type MidHandshake = MidHandshakeSslStream; + + fn handshake(mut self, stream: S) -> DtlsHandshakeResult + where + S: Read + Write, + { + let profiles_str: String = SrtpProtectionProfile::RECOMMENDED + .iter() + .map(|profile| profile.name.to_string()) + .collect::>() + .join(":"); + self.set_tlsext_use_srtp(&profiles_str).unwrap(); + match self.build().connect("invalid", stream) { + Ok(stream) => DtlsHandshakeResult::Success(stream), + Err(HandshakeError::WouldBlock(mid_handshake)) => { + DtlsHandshakeResult::WouldBlock(mid_handshake) + } + Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure( + io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()), + ), + Err(HandshakeError::SetupFailure(err)) => { + DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err)) + } + } + } +} + +impl DtlsBuilder for SslAcceptorBuilder { + type Instance = SslStream; + type MidHandshake = MidHandshakeSslStream; + + fn handshake(mut self, stream: S) -> DtlsHandshakeResult + where + S: Read + Write, + { + let profiles_str: String = SrtpProtectionProfile::RECOMMENDED + .iter() + .map(|profile| profile.name.to_string()) + .collect::>() + .join(":"); + self.set_tlsext_use_srtp(&profiles_str).unwrap(); + match self.build().accept(stream) { + Ok(stream) => DtlsHandshakeResult::Success(stream), + Err(HandshakeError::WouldBlock(mid_handshake)) => { + DtlsHandshakeResult::WouldBlock(mid_handshake) + } + Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure( + io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()), + ), + Err(HandshakeError::SetupFailure(err)) => { + DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err)) + } + } + } +} + +impl DtlsMidHandshake for MidHandshakeSslStream { + type Instance = SslStream; + + fn handshake(self) -> DtlsHandshakeResult { + match MidHandshakeSslStream::handshake(self) { + Ok(stream) => DtlsHandshakeResult::Success(stream), + Err(HandshakeError::WouldBlock(mid_handshake)) => { + DtlsHandshakeResult::WouldBlock(mid_handshake) + } + Err(HandshakeError::Failure(mid_handshake)) => DtlsHandshakeResult::Failure( + io::Error::new(io::ErrorKind::Other, mid_handshake.into_error()), + ), + Err(HandshakeError::SetupFailure(err)) => { + DtlsHandshakeResult::Failure(io::Error::new(io::ErrorKind::Other, err)) + } + } + } +} + +impl Dtls for SslStream { + fn is_server_side(&self) -> bool { + self.ssl().is_server() + } + + fn export_key(&mut self, exporter_label: &str, length: usize) -> Vec { + let mut vec = vec![0; length]; + self.ssl() + .export_keying_material(&mut vec, exporter_label, None) + .unwrap(); + vec + } +} + +#[cfg(test)] +mod test { + use rfc5764::test::DummyTransport; + use rfc5764::{DtlsSrtp, DtlsSrtpHandshakeResult}; + + use openssl::asn1::Asn1Time; + use openssl::hash::MessageDigest; + use openssl::pkey::PKey; + use openssl::rsa::Rsa; + use openssl::ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode}; + use openssl::x509::X509; + + #[test] + fn connect_and_establish_matching_key_material() { + let (client_sock, server_sock) = DummyTransport::new(); + + let rsa = Rsa::generate(2048).unwrap(); + let key = PKey::from_rsa(rsa).unwrap(); + + let mut cert_builder = X509::builder().unwrap(); + cert_builder + .set_not_after(&Asn1Time::days_from_now(1).unwrap()) + .unwrap(); + cert_builder + .set_not_before(&Asn1Time::days_from_now(0).unwrap()) + .unwrap(); + cert_builder.set_pubkey(&key).unwrap(); + cert_builder.sign(&key, MessageDigest::sha256()).unwrap(); + let cert = cert_builder.build(); + + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::dtls()).unwrap(); + let mut connector = SslConnector::builder(SslMethod::dtls()).unwrap(); + acceptor.set_certificate(&cert).unwrap(); + acceptor.set_private_key(&key).unwrap(); + acceptor.set_verify(SslVerifyMode::NONE); + connector.set_verify(SslVerifyMode::NONE); + let mut handshake_server = DtlsSrtp::handshake(server_sock, acceptor); + let mut handshake_client = DtlsSrtp::handshake(client_sock, connector); + + loop { + match (handshake_client, handshake_server) { + (DtlsSrtpHandshakeResult::Failure(err), _) => { + panic!("Client error: {}", err); + } + (_, DtlsSrtpHandshakeResult::Failure(err)) => { + panic!("Server error: {}", err); + } + ( + DtlsSrtpHandshakeResult::Success(client), + DtlsSrtpHandshakeResult::Success(server), + ) => { + assert_eq!( + client.srtp_read_context.master_key, + server.srtp_write_context.master_key + ); + assert_eq!( + client.srtp_read_context.master_salt, + server.srtp_write_context.master_salt + ); + assert_eq!( + client.srtp_write_context.master_key, + server.srtp_read_context.master_key + ); + assert_eq!( + client.srtp_write_context.master_salt, + server.srtp_read_context.master_salt + ); + return; + } + ( + DtlsSrtpHandshakeResult::WouldBlock(client), + DtlsSrtpHandshakeResult::WouldBlock(server), + ) => { + handshake_client = client.handshake(); + handshake_server = server.handshake(); + } + (client, DtlsSrtpHandshakeResult::WouldBlock(server)) => { + handshake_client = client; + handshake_server = server.handshake(); + } + (DtlsSrtpHandshakeResult::WouldBlock(client), server) => { + handshake_client = client.handshake(); + handshake_server = server; + } + } + } + } +} diff --git a/src/rfc5764/tokio.rs b/src/rfc5764/tokio.rs new file mode 100644 index 0000000..5c3c44d --- /dev/null +++ b/src/rfc5764/tokio.rs @@ -0,0 +1,95 @@ +use std::io; +use tokio::prelude::{Async, AsyncRead, AsyncSink, AsyncWrite, Future, Sink, Stream}; + +use rfc5764::{DtlsBuilder, DtlsSrtp, DtlsSrtpHandshakeResult, DtlsSrtpMuxerPart}; + +impl AsyncRead for DtlsSrtp +where + S: AsyncRead + AsyncWrite, + D: DtlsBuilder>, +{ +} + +impl AsyncWrite for DtlsSrtp +where + S: AsyncRead + AsyncWrite, + D: DtlsBuilder>, +{ + fn shutdown(&mut self) -> io::Result> { + Ok(().into()) // FIXME + } +} + +impl Future for DtlsSrtpHandshakeResult +where + S: AsyncRead + AsyncWrite, + D: DtlsBuilder>, +{ + type Item = DtlsSrtp; + type Error = io::Error; + + fn poll(&mut self) -> io::Result> { + let mut owned = DtlsSrtpHandshakeResult::Failure(io::Error::new( + io::ErrorKind::Other, + "poll called after completion", + )); + std::mem::swap(&mut owned, self); + match owned { + DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)), + DtlsSrtpHandshakeResult::WouldBlock(mid_handshake) => match mid_handshake.handshake() { + DtlsSrtpHandshakeResult::Success(dtls_srtp) => Ok(Async::Ready(dtls_srtp)), + mut new @ DtlsSrtpHandshakeResult::WouldBlock(_) => { + std::mem::swap(&mut new, self); + Ok(Async::NotReady) + } + DtlsSrtpHandshakeResult::Failure(err) => Err(err), + }, + DtlsSrtpHandshakeResult::Failure(err) => Err(err), + } + } +} + +impl Stream for DtlsSrtp +where + S: AsyncRead + AsyncWrite, + D: DtlsBuilder>, + DtlsSrtp: AsyncRead + AsyncWrite, +{ + type Item = Vec; + type Error = io::Error; + + fn poll(&mut self) -> io::Result>> { + let mut buf = [0; 2048]; + Ok(match self.poll_read(&mut buf)? { + Async::Ready(len) => { + if len == 0 { + Async::Ready(None) + } else { + Async::Ready(Some(buf[..len].to_vec())) + } + } + Async::NotReady => Async::NotReady, + }) + } +} + +impl Sink for DtlsSrtp +where + S: AsyncRead + AsyncWrite, + D: DtlsBuilder>, + DtlsSrtp: AsyncRead + AsyncWrite, +{ + type SinkItem = Vec; + type SinkError = io::Error; + + fn start_send(&mut self, buf: Self::SinkItem) -> io::Result> { + Ok(match self.poll_write(&buf[..])? { + Async::Ready(_) => AsyncSink::Ready, + Async::NotReady => AsyncSink::NotReady(buf), + }) + } + + fn poll_complete(&mut self) -> io::Result> { + self.poll_flush() + } +}