use puffin::codec::{Codec, Reader};
use crate::tls::rustls::error::Error;
use crate::tls::rustls::msgs::alert::AlertMessagePayload;
use crate::tls::rustls::msgs::base::Payload;
use crate::tls::rustls::msgs::ccs::ChangeCipherSpecPayload;
use crate::tls::rustls::msgs::enums::{
AlertDescription, AlertLevel, ContentType, HandshakeType, ProtocolVersion,
};
use crate::tls::rustls::msgs::handshake::HandshakeMessagePayload;
use crate::tls::rustls::msgs::heartbeat::HeartbeatPayload;
/// A record payload after content-type-specific parsing.
///
/// Built from raw record bytes by `MessagePayload::new` (one message) or
/// `MessagePayload::multiple_new` (all messages in a record), and turned back
/// into bytes by `MessagePayload::encode`.
#[derive(Debug, Clone)]
pub enum MessagePayload {
/// A parsed alert record body.
Alert(AlertMessagePayload),
/// A parsed handshake message.
Handshake(HandshakeMessagePayload),
/// Raw bytes of a handshake-typed record whose body could not be parsed as a
/// handshake message. NOTE(review): presumably an encrypted TLS 1.2 handshake
/// message (e.g. Finished) — confirm against callers.
TLS12EncryptedHandshake(Payload),
/// A parsed change_cipher_spec record body.
ChangeCipherSpec(ChangeCipherSpecPayload),
/// Application data, kept as opaque bytes (never parsed).
ApplicationData(Payload),
/// A parsed heartbeat record body.
Heartbeat(HeartbeatPayload),
}
impl MessagePayload {
pub fn encode(&self, bytes: &mut Vec<u8>) {
match *self {
Self::Alert(ref x) => x.encode(bytes),
Self::Handshake(ref x) => x.encode(bytes),
Self::TLS12EncryptedHandshake(ref x) => x.encode(bytes),
Self::ChangeCipherSpec(ref x) => x.encode(bytes),
Self::ApplicationData(ref x) => x.encode(bytes),
Self::Heartbeat(ref x) => x.encode(bytes),
}
}
pub fn new(typ: ContentType, vers: ProtocolVersion, payload: Payload) -> Result<Self, Error> {
let fallback_payload = payload.clone();
let mut r = Reader::init(&payload.0);
let parsed = match typ {
ContentType::ApplicationData => return Ok(Self::ApplicationData(payload)),
ContentType::Alert => AlertMessagePayload::read(&mut r).map(MessagePayload::Alert),
ContentType::Handshake => {
HandshakeMessagePayload::read_version(&mut r, vers)
.map(MessagePayload::Handshake)
.or(Some(MessagePayload::TLS12EncryptedHandshake(
fallback_payload,
)))
}
ContentType::ChangeCipherSpec => {
ChangeCipherSpecPayload::read(&mut r).map(MessagePayload::ChangeCipherSpec)
}
ContentType::Heartbeat => HeartbeatPayload::read(&mut r).map(MessagePayload::Heartbeat),
_ => None,
};
parsed.ok_or(Error::corrupt_message(typ))
}
pub fn multiple_new(
typ: ContentType,
vers: ProtocolVersion,
payload: Payload,
) -> Result<Vec<Self>, Error> {
let fallback_payload = &payload;
let mut r = Reader::init(&payload.0);
let mut parsed: Vec<Self> = vec![];
while r.any_left() {
let parsed_msg = match typ {
ContentType::ApplicationData => Some(Self::ApplicationData(payload.clone())),
ContentType::Alert => AlertMessagePayload::read(&mut r).map(MessagePayload::Alert),
ContentType::Handshake => {
HandshakeMessagePayload::read_version(&mut r, vers)
.map(MessagePayload::Handshake)
.or(Some(MessagePayload::TLS12EncryptedHandshake(
fallback_payload.clone(),
)))
}
ContentType::ChangeCipherSpec => {
ChangeCipherSpecPayload::read(&mut r).map(MessagePayload::ChangeCipherSpec)
}
ContentType::Heartbeat => {
HeartbeatPayload::read(&mut r).map(MessagePayload::Heartbeat)
}
_ => None,
};
if let Some(msg) = parsed_msg {
parsed.push(msg);
}
}
Ok(parsed)
}
pub fn content_type(&self) -> ContentType {
match self {
Self::Alert(_) => ContentType::Alert,
Self::Handshake(_) => ContentType::Handshake,
Self::TLS12EncryptedHandshake(_) => ContentType::Handshake,
Self::ChangeCipherSpec(_) => ContentType::ChangeCipherSpec,
Self::ApplicationData(_) => ContentType::ApplicationData,
Self::Heartbeat(_) => ContentType::Heartbeat,
}
}
}
/// A TLS record as framed on the wire: content type, protocol version and an
/// unparsed body. NOTE(review): the body may be ciphertext — nothing here
/// inspects it.
#[derive(Clone, Debug)]
pub struct OpaqueMessage {
/// Record content type (first header byte).
pub typ: ContentType,
/// Record-layer protocol version (two header bytes).
pub version: ProtocolVersion,
/// Record body, not including the header or its length field.
pub payload: Payload,
}
impl Codec for OpaqueMessage {
    /// Appends the wire encoding of this record (type, version, u16 length,
    /// body) to `bytes`.
    ///
    /// Produces the same bytes as the inherent `OpaqueMessage::encode`, but
    /// writes straight into `bytes`. That avoids cloning the whole payload and
    /// allocating an intermediate buffer.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.typ.encode(bytes);
        self.version.encode(bytes);
        // NOTE(review): bodies longer than u16::MAX wrap here, exactly as in the
        // inherent encoder; callers are presumably expected to stay in range.
        (self.payload.0.len() as u16).encode(bytes);
        self.payload.encode(bytes);
    }

    /// Reads one record, discarding the reason on failure.
    fn read(reader: &mut Reader) -> Option<Self> {
        // Inherent associated functions take precedence, so this calls
        // `OpaqueMessage::read` (returning `Result`), not this trait method.
        Self::read(reader).ok()
    }
}
impl OpaqueMessage {
    /// Header length: content type (1) + version (2) + body length (2).
    const HEADER_SIZE: u16 = 1 + 2 + 2;
    /// Upper bound on the body length accepted by `read`.
    /// NOTE(review): presumably mirrors the TLS 1.2 ciphertext limit of
    /// 2^14 + 2048; `read` rejects lengths `>=` this value, so a body of
    /// exactly this size is refused — confirm that is intended.
    const MAX_PAYLOAD: u16 = 16384 + 2048;
    /// Header plus maximum body size, in bytes.
    pub const MAX_WIRE_SIZE: usize = (Self::MAX_PAYLOAD + Self::HEADER_SIZE) as usize;

    /// Reads and validates one record from `r`.
    ///
    /// # Errors
    /// Checks run in this order:
    /// - [`MessageError::TooShortForHeader`] if any header field is missing.
    /// - [`MessageError::IllegalLength`] for an empty non-application-data
    ///   record, or a length not below `MAX_PAYLOAD`.
    /// - [`MessageError::IllegalContentType`] for an unknown content type.
    /// - [`MessageError::IllegalProtocolVersion`] for an unknown version whose
    ///   high byte is not `0x03`.
    /// - [`MessageError::TooShortForLength`] if fewer body bytes remain than
    ///   declared.
    pub fn read(r: &mut Reader) -> Result<Self, MessageError> {
        // Non-capturing closure, hence Copy: reusable for each header field.
        let short = || MessageError::TooShortForHeader;
        let typ = ContentType::read(r).ok_or_else(short)?;
        let version = ProtocolVersion::read(r).ok_or_else(short)?;
        let len = u16::read(r).ok_or_else(short)?;

        let empty_non_appdata = len == 0 && typ != ContentType::ApplicationData;
        if empty_non_appdata || len >= Self::MAX_PAYLOAD {
            return Err(MessageError::IllegalLength);
        }
        if matches!(typ, ContentType::Unknown(_)) {
            return Err(MessageError::IllegalContentType);
        }
        if let ProtocolVersion::Unknown(raw) = version {
            // Unknown versions are tolerated only within the 0x03xx family.
            if (raw & 0xff00) != 0x0300 {
                return Err(MessageError::IllegalProtocolVersion);
            }
        }

        let mut body = r
            .sub(usize::from(len))
            .ok_or(MessageError::TooShortForLength)?;
        Ok(Self {
            typ,
            version,
            payload: Payload::read(&mut body),
        })
    }

    /// Serialises the record as header (type, version, u16 body length)
    /// followed by the body.
    pub fn encode(self) -> Vec<u8> {
        let Self {
            typ,
            version,
            payload,
        } = self;
        let mut out = Vec::with_capacity(usize::from(Self::HEADER_SIZE) + payload.0.len());
        typ.encode(&mut out);
        version.encode(&mut out);
        (payload.0.len() as u16).encode(&mut out);
        payload.encode(&mut out);
        out
    }

    /// Re-labels this record as a [`PlainMessage`] without touching the body.
    pub fn into_plain_message(self) -> PlainMessage {
        let Self {
            typ,
            version,
            payload,
        } = self;
        PlainMessage {
            typ,
            version,
            payload,
        }
    }
}
impl From<Message> for PlainMessage {
fn from(msg: Message) -> Self {
let typ = msg.payload.content_type();
let payload = match msg.payload {
MessagePayload::ApplicationData(payload) => payload,
_ => {
let mut buf = Vec::new();
msg.payload.encode(&mut buf);
Payload(buf)
}
};
Self {
typ,
version: msg.version,
payload,
}
}
}
/// A record whose body is held as raw bytes rather than a parsed
/// [`MessagePayload`]. NOTE(review): the name suggests the body is plaintext —
/// confirm at call sites.
#[derive(Clone, Debug)]
pub struct PlainMessage {
/// Record content type.
pub typ: ContentType,
/// Record-layer protocol version.
pub version: ProtocolVersion,
/// Record body bytes.
pub payload: Payload,
}
impl PlainMessage {
    /// Re-labels this record as an [`OpaqueMessage`] without transforming the
    /// body (no encryption is applied).
    pub fn into_unencrypted_opaque(self) -> OpaqueMessage {
        let Self {
            typ,
            version,
            payload,
        } = self;
        OpaqueMessage {
            typ,
            version,
            payload,
        }
    }

    /// Returns a view of this record that borrows the body bytes instead of
    /// owning them.
    pub fn borrow(&self) -> BorrowedPlainMessage<'_> {
        BorrowedPlainMessage {
            typ: self.typ,
            version: self.version,
            payload: self.payload.0.as_slice(),
        }
    }
}
/// A record whose body has been parsed into a [`MessagePayload`].
///
/// The content type is not stored separately; it is derived from the payload
/// variant via `MessagePayload::content_type`.
#[derive(Debug, Clone)]
pub struct Message {
/// Record-layer protocol version.
pub version: ProtocolVersion,
/// Parsed record body.
pub payload: MessagePayload,
}
impl Message {
    /// Returns `true` iff this is a parsed handshake message of type `hstyp`.
    /// Unparsed (`TLS12EncryptedHandshake`) payloads never match.
    pub fn is_handshake_type(&self, hstyp: HandshakeType) -> bool {
        matches!(&self.payload, MessagePayload::Handshake(hsp) if hsp.typ == hstyp)
    }

    /// Builds an alert message with the given level and description, labelled
    /// with protocol version TLS 1.2.
    pub fn build_alert(level: AlertLevel, desc: AlertDescription) -> Self {
        let alert = AlertMessagePayload {
            level,
            description: desc,
        };
        Self {
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::Alert(alert),
        }
    }

    /// Wraps `HandshakeMessagePayload::build_key_update_notify()` in a message
    /// labelled with protocol version TLS 1.3.
    pub fn build_key_update_notify() -> Self {
        let handshake = HandshakeMessagePayload::build_key_update_notify();
        Self {
            version: ProtocolVersion::TLSv1_3,
            payload: MessagePayload::Handshake(handshake),
        }
    }
}
impl TryFrom<PlainMessage> for Message {
type Error = Error;
fn try_from(plain: PlainMessage) -> Result<Self, Self::Error> {
Ok(Self {
version: plain.version,
payload: MessagePayload::new(plain.typ, plain.version, plain.payload)?,
})
}
}
impl TryFrom<OpaqueMessage> for Message {
type Error = Error;
fn try_from(value: OpaqueMessage) -> Result<Self, Self::Error> {
Message::try_from(value.into_plain_message())
}
}
/// A borrowed view of a [`PlainMessage`], as returned by `PlainMessage::borrow`.
pub struct BorrowedPlainMessage<'a> {
/// Record content type.
pub typ: ContentType,
/// Record-layer protocol version.
pub version: ProtocolVersion,
/// Record body bytes, borrowed from the owning message.
pub payload: &'a [u8],
}
/// Reasons `OpaqueMessage::read` can reject a record.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be
/// compared directly (e.g. `assert_eq!` in tests, `==` in callers).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageError {
    /// The input ended before the content type, version and length fields
    /// were all read.
    TooShortForHeader,
    /// The header was complete, but fewer body bytes remained than the length
    /// field declared.
    TooShortForLength,
    /// The length was zero for a non-application-data record, or was not below
    /// the maximum payload size.
    IllegalLength,
    /// The content type did not map to a known value.
    IllegalContentType,
    /// The version did not map to a known value and its high byte was not
    /// `0x03`.
    IllegalProtocolVersion,
}