Support key_derivation_rate and fix update_session_keys only working by accident

dtls-srtp
Jonas Herzig 2018-11-08 17:15:15 +01:00
parent febcc71ac8
commit 0b5229d36f
1 changed files with 99 additions and 10 deletions

View File

@ -40,6 +40,9 @@ pub struct SrtpContext {
// TODO: support other fields // TODO: support other fields
pub master_key: Vec<u8>, pub master_key: Vec<u8>,
pub master_salt: Vec<u8>, pub master_salt: Vec<u8>,
// Since actual kdr is a power of two, this only stores the power (+1).
// i.e. actual kdr is 2^(key_derivation_rate-1) (or 0 in case of 0)
pub key_derivation_rate: u8,
pub rollover_counter: u32, pub rollover_counter: u32,
pub highest_recv_seq_num: u16, pub highest_recv_seq_num: u16,
pub encryption: EncryptionAlgorithm, pub encryption: EncryptionAlgorithm,
@ -56,6 +59,7 @@ impl SrtpContext {
SrtpContext { SrtpContext {
master_key: Vec::from(master_key), master_key: Vec::from(master_key),
master_salt: Vec::from(master_salt), master_salt: Vec::from(master_salt),
key_derivation_rate: 0,
rollover_counter: 0, rollover_counter: 0,
highest_recv_seq_num: 0, highest_recv_seq_num: 0,
encryption: EncryptionAlgorithm::default(), encryption: EncryptionAlgorithm::default(),
@ -67,8 +71,16 @@ impl SrtpContext {
auth_tag_len: 80 / 8, auth_tag_len: 80 / 8,
} }
} }
pub fn update_session_keys(&mut self) { pub fn update_session_keys(&mut self, index: u64) {
let index = ((self.rollover_counter as u64) << 16) + self.highest_recv_seq_num as u64; let index = if self.key_derivation_rate == 0 {
0
} else {
index >> (self.key_derivation_rate - 1)
};
// TODO: only recalculate if index changed, probably also cache surrounding indices
// but make sure the initial updates happens
let index = BigUint::from(index); let index = BigUint::from(index);
let enc_key_id = BigUint::from_bytes_be(&[0, 0, 0, 0, 0, 0, 0]) + index.clone(); let enc_key_id = BigUint::from_bytes_be(&[0, 0, 0, 0, 0, 0, 0]) + index.clone();
@ -154,6 +166,45 @@ impl SrtpContext {
}; };
(probable_roc, ((probable_roc as PacketIndex) << 16) + seq_num as PacketIndex) (probable_roc, ((probable_roc as PacketIndex) << 16) + seq_num as PacketIndex)
} }
// https://tools.ietf.org/html/rfc3711#section-3.3
pub fn process_incoming(&mut self, packet: &[u8]) -> Result<Vec<u8>> {
// Step 1: determining the correct context (has already happened at this point)
// Step 2: Determine index of the SRTP packet
let reader = &mut &packet[..];
let header = track_try!(rfc3550::RtpFixedHeader::read_from(reader));
let seq_num = header.seq_num;
let (rollover_counter, index) = self.estimate_packet_index(seq_num);
// Step 3: Determine master key and salt
// TODO: support re-keying
// TODO: support MKI
// Step 4: Determine session keys and salt
self.update_session_keys(index);
// Step 5: Replay protection and authentication
// TODO: replay protection
self.authenticate(packet)?;
// Step 6: Decryption
let result = self.decrypt(packet)?;
// Step 7: Update ROC, highest sequence number and replay protection
// TODO: replay protection
// https://tools.ietf.org/html/rfc3711#section-3.3.1
if rollover_counter == self.rollover_counter {
if seq_num > self.highest_recv_seq_num {
self.highest_recv_seq_num = seq_num;
}
} else if rollover_counter > self.rollover_counter {
self.highest_recv_seq_num = seq_num;
self.rollover_counter = rollover_counter;
}
Ok(result)
}
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -161,6 +212,9 @@ pub struct SrtcpContext {
// TODO: support other fields // TODO: support other fields
pub master_key: Vec<u8>, pub master_key: Vec<u8>,
pub master_salt: Vec<u8>, pub master_salt: Vec<u8>,
// Since actual kdr is a power of two, this only stores the power (+1).
// i.e. actual kdr is 2^(key_derivation_rate-1) (or 0 in case of 0)
pub key_derivation_rate: u8,
pub highest_recv_index: u32, // NOTE: 31-bits pub highest_recv_index: u32, // NOTE: 31-bits
pub encryption: EncryptionAlgorithm, pub encryption: EncryptionAlgorithm,
pub authentication: AuthenticationAlgorithm, pub authentication: AuthenticationAlgorithm,
@ -176,6 +230,7 @@ impl SrtcpContext {
SrtcpContext { SrtcpContext {
master_key: Vec::from(master_key), master_key: Vec::from(master_key),
master_salt: Vec::from(master_salt), master_salt: Vec::from(master_salt),
key_derivation_rate: 0,
highest_recv_index: 0, highest_recv_index: 0,
encryption: EncryptionAlgorithm::default(), encryption: EncryptionAlgorithm::default(),
authentication: AuthenticationAlgorithm::default(), authentication: AuthenticationAlgorithm::default(),
@ -186,9 +241,18 @@ impl SrtcpContext {
auth_tag_len: 80 / 8, auth_tag_len: 80 / 8,
} }
} }
pub fn update_session_keys(&mut self) { pub fn update_session_keys(&mut self, index: u32) {
let index = if self.key_derivation_rate == 0 {
0
} else {
index >> (self.key_derivation_rate - 1)
};
// TODO: only recalculate if index changed, probably also cache surrounding indices
// but make sure the initial updates happens
// See: https://tools.ietf.org/html/rfc3711#section-4.3.2 // See: https://tools.ietf.org/html/rfc3711#section-4.3.2
let index = BigUint::from(self.highest_recv_index); let index = BigUint::from(index);
let enc_key_id = BigUint::from_bytes_be(&[3, 0, 0, 0, 0, 0, 0]) + index.clone(); let enc_key_id = BigUint::from_bytes_be(&[3, 0, 0, 0, 0, 0, 0]) + index.clone();
let auth_key_id = BigUint::from_bytes_be(&[4, 0, 0, 0, 0, 0, 0]) + index.clone(); let auth_key_id = BigUint::from_bytes_be(&[4, 0, 0, 0, 0, 0, 0]) + index.clone();
@ -251,6 +315,35 @@ impl SrtcpContext {
Ok(decrypted) Ok(decrypted)
} }
// https://tools.ietf.org/html/rfc3711#section-3.3
// https://tools.ietf.org/html/rfc3711#section-3.4
pub fn process_incoming(&mut self, packet: &[u8]) -> Result<Vec<u8>> {
// Step 1: determining the correct context (has already happened at this point)
// Step 2: Determine index of the SRTCP packet
let index = track_try!((&mut &packet[packet.len() - self.auth_tag_len - 4..]).read_u32be());
let index = index & 0x7FFF_FFFF; // remove uppermost bit which isn't part of the index
// Step 3: Determine master key and salt
// TODO: support re-keying
// TODO: support MKI
// Step 4: Determine session keys and salt
self.update_session_keys(index);
// Step 5: Replay protection and authentication
// TODO: replay protection
track_try!(self.authenticate(packet));
// Step 6: Decryption
let result = track_try!(self.decrypt(packet));
// Step 7: Update replay protection
// TODO: replay protection
Ok(result)
}
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -264,7 +357,6 @@ where
T::Packet: RtpPacket, T::Packet: RtpPacket,
{ {
pub fn new(mut context: SrtpContext, inner: T) -> Self { pub fn new(mut context: SrtpContext, inner: T) -> Self {
context.update_session_keys();
SrtpPacketReader { SrtpPacketReader {
context: context, context: context,
inner: inner, inner: inner,
@ -279,8 +371,7 @@ where
type Packet = T::Packet; type Packet = T::Packet;
fn read_packet<R: Read>(&mut self, reader: &mut R) -> Result<Self::Packet> { fn read_packet<R: Read>(&mut self, reader: &mut R) -> Result<Self::Packet> {
let packet_bytes = track_try!(reader.read_all_bytes()); let packet_bytes = track_try!(reader.read_all_bytes());
track_try!(self.context.authenticate(&packet_bytes)); let decrypted_packet_bytes = track_try!(self.context.process_incoming(&packet_bytes));
let decrypted_packet_bytes = track_try!(self.context.decrypt(&packet_bytes));
track_err!(self.inner.read_packet(&mut &decrypted_packet_bytes[..])) track_err!(self.inner.read_packet(&mut &decrypted_packet_bytes[..]))
} }
@ -300,7 +391,6 @@ where
T::Packet: RtcpPacket, T::Packet: RtcpPacket,
{ {
pub fn new(mut context: SrtcpContext, inner: T) -> Self { pub fn new(mut context: SrtcpContext, inner: T) -> Self {
context.update_session_keys();
SrtcpPacketReader { SrtcpPacketReader {
context: context, context: context,
inner: inner, inner: inner,
@ -315,8 +405,7 @@ where
type Packet = T::Packet; type Packet = T::Packet;
fn read_packet<R: Read>(&mut self, reader: &mut R) -> Result<Self::Packet> { fn read_packet<R: Read>(&mut self, reader: &mut R) -> Result<Self::Packet> {
let packet_bytes = track_try!(reader.read_all_bytes()); let packet_bytes = track_try!(reader.read_all_bytes());
track_try!(self.context.authenticate(&packet_bytes)); let decrypted_packet_bytes = track_try!(self.context.process_incoming(&packet_bytes));
let decrypted_packet_bytes = track_try!(self.context.decrypt(&packet_bytes));
track_err!(self.inner.read_packet(&mut &decrypted_packet_bytes[..])) track_err!(self.inner.read_packet(&mut &decrypted_packet_bytes[..]))
} }