use std::collections::VecDeque;
use puffin::codec;
use crate::tls::rustls::msgs::enums::{ContentType, ProtocolVersion};
use crate::tls::rustls::msgs::handshake::HandshakeMessagePayload;
use crate::tls::rustls::msgs::message::{Message, MessagePayload, PlainMessage};
/// Size of a handshake message header: a one-byte handshake type followed by a
/// three-byte (u24) body length.
const HEADER_SIZE: usize = 4;
/// Largest declared handshake body length (in bytes) the joiner accepts.
const MAX_HANDSHAKE_SIZE: u32 = 0x0000_ffff;
/// Reassembles handshake messages from plaintext records.
///
/// A single record may carry several messages, and a single message may be
/// split across several records.
pub struct HandshakeJoiner {
    /// Handshake bytes received but not yet deframed into `frames`.
    buf: Vec<u8>,
    /// Fully deframed handshake messages, oldest first.
    pub frames: VecDeque<Message>,
}
impl Default for HandshakeJoiner {
fn default() -> Self {
Self::new()
}
}
/// Classification of the bytes at the front of the joiner's buffer.
enum BufferState {
    /// Fewer bytes than a header, or than the header's declared body length.
    NeedsMoreData,
    /// A complete message (header plus declared body) is available.
    OneMessage,
    /// The header declares a body longer than `MAX_HANDSHAKE_SIZE`.
    MessageTooLarge,
}
impl HandshakeJoiner {
pub fn new() -> Self {
Self {
frames: VecDeque::new(),
buf: Vec::new(),
}
}
pub fn want_message(&self, msg: &PlainMessage) -> bool {
msg.typ == ContentType::Handshake
}
pub fn is_empty(&self) -> bool {
self.buf.is_empty()
}
pub fn take_message(&mut self, msg: PlainMessage) -> Option<usize> {
if self.buf.is_empty() {
self.buf = msg.payload.0;
} else {
self.buf.extend_from_slice(&msg.payload.0[..]);
}
let mut count = 0;
loop {
match self.buf_contains_message() {
BufferState::MessageTooLarge => return None,
BufferState::NeedsMoreData => break,
BufferState::OneMessage => {
if !self.deframe_one(msg.version) {
return None;
}
count += 1;
}
}
}
Some(count)
}
fn buf_contains_message(&self) -> BufferState {
if self.buf.len() < HEADER_SIZE {
return BufferState::NeedsMoreData;
}
let (header, rest) = self.buf.split_at(HEADER_SIZE);
match codec::u24::decode(&header[1..]) {
Some(len) if len.0 > MAX_HANDSHAKE_SIZE => BufferState::MessageTooLarge,
Some(len) if rest.get(..len.into()).is_some() => BufferState::OneMessage,
_ => BufferState::NeedsMoreData,
}
}
fn deframe_one(&mut self, version: ProtocolVersion) -> bool {
let used = {
let mut rd = codec::Reader::init(&self.buf);
let payload = match HandshakeMessagePayload::read_version(&mut rd, version) {
Some(p) => p,
None => return false,
};
let m = Message {
version,
payload: MessagePayload::Handshake(payload),
};
self.frames.push_back(m);
rd.used()
};
self.buf = self.buf.split_off(used);
true
}
}
#[cfg(test)]
mod tests {
    use puffin::codec::Codec;

    use super::HandshakeJoiner;
    use crate::tls::rustls::msgs::base::Payload;
    use crate::tls::rustls::msgs::enums::{ContentType, HandshakeType, ProtocolVersion};
    use crate::tls::rustls::msgs::handshake::{HandshakeMessagePayload, HandshakePayload};
    use crate::tls::rustls::msgs::message::{Message, MessagePayload, PlainMessage};

    /// Builds a TLS 1.2 plaintext record of content type `typ` carrying `bytes`.
    fn record(typ: ContentType, bytes: &[u8]) -> PlainMessage {
        PlainMessage {
            typ,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(bytes.to_vec()),
        }
    }

    /// Builds the plaintext form of a TLS 1.2 handshake message, for comparison
    /// with `pop_eq`.
    fn expected_handshake(typ: HandshakeType, payload: HandshakePayload) -> PlainMessage {
        Message {
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::Handshake(HandshakeMessagePayload { typ, payload }),
        }
        .into()
    }

    /// Pops the oldest deframed message and asserts that it matches `expect`.
    ///
    /// Content type, version and encoded payload bytes are all compared.
    fn pop_eq(expect: &PlainMessage, hj: &mut HandshakeJoiner) {
        let got = hj.frames.pop_front().unwrap();
        assert_eq!(got.payload.content_type(), expect.typ);
        assert_eq!(got.version, expect.version);

        let mut got_bytes = Vec::new();
        got.payload.encode(&mut got_bytes);
        let mut expect_bytes = Vec::new();
        expect.payload.encode(&mut expect_bytes);
        assert_eq!(got_bytes, expect_bytes);
    }

    #[test_log::test]
    fn want() {
        let hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        let wanted = record(ContentType::Handshake, b"hello world");
        let unwanted = record(ContentType::Alert, b"ponytown");
        assert!(hj.want_message(&wanted));
        assert!(!hj.want_message(&unwanted));
    }

    #[test_log::test]
    fn split() {
        // Two back-to-back empty HelloRequest messages packed into one record.
        let mut hj = HandshakeJoiner::new();
        let msg = record(ContentType::Handshake, b"\x00\x00\x00\x00\x00\x00\x00\x00");
        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(2));
        assert!(hj.is_empty());

        let expect =
            expected_handshake(HandshakeType::HelloRequest, HandshakePayload::HelloRequest);
        for _ in 0..2 {
            pop_eq(&expect, &mut hj);
        }
    }

    #[test_log::test]
    fn broken() {
        // A handshake message whose two-byte body cannot be parsed.
        let mut hj = HandshakeJoiner::new();
        let msg = record(ContentType::Handshake, b"\x01\x00\x00\x02\xff\xff");
        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), None);
    }

    #[test_log::test]
    fn join() {
        // A Finished message with a 16-byte body, delivered in three fragments.
        let mut hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        // (fragment bytes, messages expected from this fragment, buffer empty afterwards)
        let fragments: [(&[u8], usize, bool); 3] = [
            (&b"\x14\x00\x00\x10\x00\x01\x02\x03\x04"[..], 0, false),
            (&b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e"[..], 0, false),
            (&b"\x0f"[..], 1, true),
        ];
        for &(bytes, deframed, drained) in &fragments {
            let msg = record(ContentType::Handshake, bytes);
            assert!(hj.want_message(&msg));
            assert_eq!(hj.take_message(msg), Some(deframed));
            assert_eq!(hj.is_empty(), drained);
        }

        let body: Vec<u8> = (0x00u8..=0x0f).collect();
        let expect = expected_handshake(
            HandshakeType::Finished,
            HandshakePayload::Finished(Payload::new(body)),
        );
        pop_eq(&expect, &mut hj);
    }

    #[test_log::test]
    fn test_rejects_giant_certs() {
        // Type 0x0b with a declared body of 0x010004 bytes, above the size limit.
        let mut hj = HandshakeJoiner::new();
        let msg = record(
            ContentType::Handshake,
            b"\x0b\x01\x00\x04\x01\x00\x01\x00\xff\xfe",
        );
        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), None);
        assert!(!hj.is_empty());
    }
}