diff --git a/src/rfc3711.rs b/src/rfc3711.rs index d4e427e..5c09638 100644 --- a/src/rfc3711.rs +++ b/src/rfc3711.rs @@ -40,6 +40,9 @@ pub struct SrtpContext { // TODO: support other fields pub master_key: Vec, pub master_salt: Vec, + // Since actual kdr is a power of two, this only stores the power (+1). + // i.e. actual kdr is 2^(key_derivation_rate-1) (or 0 in case of 0) + pub key_derivation_rate: u8, pub rollover_counter: u32, pub highest_recv_seq_num: u16, pub encryption: EncryptionAlgorithm, @@ -56,6 +59,7 @@ impl SrtpContext { SrtpContext { master_key: Vec::from(master_key), master_salt: Vec::from(master_salt), + key_derivation_rate: 0, rollover_counter: 0, highest_recv_seq_num: 0, encryption: EncryptionAlgorithm::default(), @@ -67,8 +71,16 @@ impl SrtpContext { auth_tag_len: 80 / 8, } } - pub fn update_session_keys(&mut self) { - let index = ((self.rollover_counter as u64) << 16) + self.highest_recv_seq_num as u64; + pub fn update_session_keys(&mut self, index: u64) { + let index = if self.key_derivation_rate == 0 { + 0 + } else { + index >> (self.key_derivation_rate - 1) + }; + + // TODO: only recalculate if index changed, probably also cache surrounding indices + // but make sure the initial updates happens + let index = BigUint::from(index); let enc_key_id = BigUint::from_bytes_be(&[0, 0, 0, 0, 0, 0, 0]) + index.clone(); @@ -154,6 +166,45 @@ impl SrtpContext { }; (probable_roc, ((probable_roc as PacketIndex) << 16) + seq_num as PacketIndex) } + + // https://tools.ietf.org/html/rfc3711#section-3.3 + pub fn process_incoming(&mut self, packet: &[u8]) -> Result> { + // Step 1: determining the correct context (has already happened at this point) + + // Step 2: Determine index of the SRTP packet + let reader = &mut &packet[..]; + let header = track_try!(rfc3550::RtpFixedHeader::read_from(reader)); + let seq_num = header.seq_num; + let (rollover_counter, index) = self.estimate_packet_index(seq_num); + + // Step 3: Determine master key and salt + // TODO: support re-keying + // TODO: support MKI + + // Step 4: Determine session keys and salt + self.update_session_keys(index); + + // Step 5: Replay protection and authentication + // TODO: replay protection + self.authenticate(packet)?; + + // Step 6: Decryption + let result = self.decrypt(packet)?; + + // Step 7: Update ROC, highest sequence number and replay protection + // TODO: replay protection + // https://tools.ietf.org/html/rfc3711#section-3.3.1 + if rollover_counter == self.rollover_counter { + if seq_num > self.highest_recv_seq_num { + self.highest_recv_seq_num = seq_num; + } + } else if rollover_counter > self.rollover_counter { + self.highest_recv_seq_num = seq_num; + self.rollover_counter = rollover_counter; + } + + Ok(result) + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -161,6 +212,9 @@ pub struct SrtcpContext { // TODO: support other fields pub master_key: Vec, pub master_salt: Vec, + // Since actual kdr is a power of two, this only stores the power (+1). + // i.e. actual kdr is 2^(key_derivation_rate-1) (or 0 in case of 0) + pub key_derivation_rate: u8, pub highest_recv_index: u32, // NOTE: 31-bits pub encryption: EncryptionAlgorithm, pub authentication: AuthenticationAlgorithm, @@ -176,6 +230,7 @@ impl SrtcpContext { SrtcpContext { master_key: Vec::from(master_key), master_salt: Vec::from(master_salt), + key_derivation_rate: 0, highest_recv_index: 0, encryption: EncryptionAlgorithm::default(), authentication: AuthenticationAlgorithm::default(), @@ -186,9 +241,18 @@ impl SrtcpContext { auth_tag_len: 80 / 8, } } - pub fn update_session_keys(&mut self) { + pub fn update_session_keys(&mut self, index: u32) { + let index = if self.key_derivation_rate == 0 { + 0 + } else { + index >> (self.key_derivation_rate - 1) + }; + + // TODO: only recalculate if index changed, probably also cache surrounding indices + // but make sure the initial updates happens + // See: https://tools.ietf.org/html/rfc3711#section-4.3.2 - let index = BigUint::from(self.highest_recv_index); + let index = BigUint::from(index); let enc_key_id = BigUint::from_bytes_be(&[3, 0, 0, 0, 0, 0, 0]) + index.clone(); let auth_key_id = BigUint::from_bytes_be(&[4, 0, 0, 0, 0, 0, 0]) + index.clone(); @@ -251,6 +315,35 @@ impl SrtcpContext { Ok(decrypted) } + + // https://tools.ietf.org/html/rfc3711#section-3.3 + // https://tools.ietf.org/html/rfc3711#section-3.4 + pub fn process_incoming(&mut self, packet: &[u8]) -> Result> { + // Step 1: determining the correct context (has already happened at this point) + + // Step 2: Determine index of the SRTCP packet + let index = track_try!((&mut &packet[packet.len() - self.auth_tag_len - 4..]).read_u32be()); + let index = index & 0x7FFF_FFFF; // remove uppermost bit which isn't part of the index + + // Step 3: Determine master key and salt + // TODO: support re-keying + // TODO: support MKI + + // Step 4: Determine session keys and salt + self.update_session_keys(index); + + // Step 5: Replay protection and authentication + // TODO: replay protection + track_try!(self.authenticate(packet)); + + // Step 6: Decryption + let result = track_try!(self.decrypt(packet)); + + // Step 7: Update replay protection + // TODO: replay protection + + Ok(result) + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -264,7 +357,6 @@ where T::Packet: RtpPacket, { pub fn new(mut context: SrtpContext, inner: T) -> Self { - context.update_session_keys(); SrtpPacketReader { context: context, inner: inner, @@ -279,8 +371,7 @@ where type Packet = T::Packet; fn read_packet(&mut self, reader: &mut R) -> Result { let packet_bytes = track_try!(reader.read_all_bytes()); - track_try!(self.context.authenticate(&packet_bytes)); - let decrypted_packet_bytes = track_try!(self.context.decrypt(&packet_bytes)); + let decrypted_packet_bytes = track_try!(self.context.process_incoming(&packet_bytes)); track_err!(self.inner.read_packet(&mut &decrypted_packet_bytes[..])) } @@ -300,7 +391,6 @@ where T::Packet: RtcpPacket, { pub fn new(mut context: SrtcpContext, inner: T) -> Self { - context.update_session_keys(); SrtcpPacketReader { context: context, inner: inner, @@ -315,8 +405,7 @@ where type Packet = T::Packet; fn read_packet(&mut self, reader: &mut R) -> Result { let packet_bytes = track_try!(reader.read_all_bytes()); - track_try!(self.context.authenticate(&packet_bytes)); - let decrypted_packet_bytes = track_try!(self.context.decrypt(&packet_bytes)); + let decrypted_packet_bytes = track_try!(self.context.process_incoming(&packet_bytes)); track_err!(self.inner.read_packet(&mut &decrypted_packet_bytes[..])) }