use std::any::TypeId;
use puffin::codec;
use puffin::codec::{Codec, Reader, VecCodecWoSize};
use puffin::error::Error::Term;
use puffin::protocol::{EvaluatedTerm, ProtocolMessage};
use crate::protocol::{MessageFlight, OpaqueMessageFlight, TLSProtocolTypes};
use crate::tls::rustls::error::Error;
use crate::tls::rustls::hash_hs::HandshakeHash;
use crate::tls::rustls::key::Certificate;
use crate::tls::rustls::msgs::alert::AlertMessagePayload;
use crate::tls::rustls::msgs::base::{Payload, PayloadU16, PayloadU24, PayloadU8};
use crate::tls::rustls::msgs::ccs::ChangeCipherSpecPayload;
use crate::tls::rustls::msgs::enums::ContentType::ApplicationData;
use crate::tls::rustls::msgs::enums::ProtocolVersion::TLSv1_3;
use crate::tls::rustls::msgs::enums::{
AlertDescription, AlertLevel, CipherSuite, Compression, ContentType, HandshakeType, NamedGroup,
ProtocolVersion, SignatureScheme,
};
use crate::tls::rustls::msgs::handshake::{
CertReqExtension, CertificateEntries, CertificateEntry, CertificateExtension, CipherSuites,
ClientExtension, ClientExtensions, Compressions, HandshakeMessagePayload, HelloRetryExtension,
HelloRetryExtensions, NewSessionTicketExtension, NewSessionTicketExtensions,
PresharedKeyIdentity, Random, ServerExtension, ServerExtensions, SessionID, VecU16OfPayloadU16,
VecU16OfPayloadU8,
};
use crate::tls::rustls::msgs::heartbeat::HeartbeatPayload;
/// Decoded body of a TLS record, one variant per content type handled in this module.
#[derive(Clone, Debug)]
pub enum MessagePayload {
    /// Alert record body.
    Alert(AlertMessagePayload),
    /// Handshake record body that parsed as a handshake message.
    Handshake(HandshakeMessagePayload),
    /// Handshake record body that could not be parsed as a handshake message and
    /// is therefore kept as raw bytes (presumably an encrypted TLS 1.2 handshake).
    TLS12EncryptedHandshake(Payload),
    /// ChangeCipherSpec record body.
    ChangeCipherSpec(ChangeCipherSpecPayload),
    /// Application data, kept as raw bytes.
    ApplicationData(Payload),
    /// Heartbeat record body.
    Heartbeat(HeartbeatPayload),
}
impl codec::CodecP for MessagePayload {
fn encode(&self, bytes: &mut Vec<u8>) {
MessagePayload::encode(self, bytes);
}
fn read(&mut self, _: &mut Reader) -> Result<(), puffin::error::Error> {
Err(puffin::error::Error::Term(format!(
"Failed to read for type {:?}",
std::any::type_name::<MessagePayload>()
)))
}
}
impl MessagePayload {
pub fn encode(&self, bytes: &mut Vec<u8>) {
match *self {
Self::Alert(ref x) => x.encode(bytes),
Self::Handshake(ref x) => x.encode(bytes),
Self::TLS12EncryptedHandshake(ref x) => x.encode(bytes),
Self::ChangeCipherSpec(ref x) => x.encode(bytes),
Self::ApplicationData(ref x) => x.encode(bytes),
Self::Heartbeat(ref x) => x.encode(bytes),
}
}
pub fn new(typ: ContentType, vers: ProtocolVersion, payload: Payload) -> Result<Self, Error> {
let fallback_payload = payload.clone();
let mut r = Reader::init(&payload.0);
let parsed = match typ {
ContentType::ApplicationData => return Ok(Self::ApplicationData(payload)),
ContentType::Alert => AlertMessagePayload::read(&mut r).map(MessagePayload::Alert),
ContentType::Handshake => {
HandshakeMessagePayload::read_version(&mut r, vers)
.map(MessagePayload::Handshake)
.or(Some(MessagePayload::TLS12EncryptedHandshake(
fallback_payload,
)))
}
ContentType::ChangeCipherSpec => {
ChangeCipherSpecPayload::read(&mut r).map(MessagePayload::ChangeCipherSpec)
}
ContentType::Heartbeat => HeartbeatPayload::read(&mut r).map(MessagePayload::Heartbeat),
_ => None,
};
parsed.ok_or(Error::corrupt_message(typ))
}
pub fn multiple_new(
typ: ContentType,
vers: ProtocolVersion,
payload: Payload,
) -> Result<Vec<Self>, Error> {
let fallback_payload = &payload;
let mut r = Reader::init(&payload.0);
let mut parsed: Vec<Self> = vec![];
while r.any_left() {
let parsed_msg = match typ {
ContentType::ApplicationData => Some(Self::ApplicationData(payload.clone())),
ContentType::Alert => AlertMessagePayload::read(&mut r).map(MessagePayload::Alert),
ContentType::Handshake => {
HandshakeMessagePayload::read_version(&mut r, vers)
.map(MessagePayload::Handshake)
.or(Some(MessagePayload::TLS12EncryptedHandshake(
fallback_payload.clone(),
)))
}
ContentType::ChangeCipherSpec => {
ChangeCipherSpecPayload::read(&mut r).map(MessagePayload::ChangeCipherSpec)
}
ContentType::Heartbeat => {
HeartbeatPayload::read(&mut r).map(MessagePayload::Heartbeat)
}
_ => None,
};
if let Some(msg) = parsed_msg {
parsed.push(msg);
}
}
Ok(parsed)
}
pub fn content_type(&self) -> ContentType {
match self {
Self::Alert(_) => ContentType::Alert,
Self::Handshake(_) => ContentType::Handshake,
Self::TLS12EncryptedHandshake(_) => ContentType::Handshake,
Self::ChangeCipherSpec(_) => ContentType::ChangeCipherSpec,
Self::ApplicationData(_) => ContentType::ApplicationData,
Self::Heartbeat(_) => ContentType::Heartbeat,
}
}
}
/// A raw TLS record: header fields plus the undecoded payload bytes.
#[derive(Debug, Clone)]
pub struct OpaqueMessage {
    /// Record content type from the header.
    pub typ: ContentType,
    /// Record protocol version from the header.
    pub version: ProtocolVersion,
    /// Raw record body.
    pub payload: Payload,
}
impl Codec for OpaqueMessage {
fn encode(&self, bytes: &mut Vec<u8>) {
self.typ.encode(bytes);
self.version.encode(bytes);
(self.payload.0.len() as u16).encode(bytes);
self.payload.encode(bytes);
}
fn read(reader: &mut Reader) -> Option<Self> {
Self::read(reader).ok()
}
}
/// Size limits and parsing for raw TLS records.
impl OpaqueMessage {
// Width of the record header: content type, protocol version and the u16 length
// field (presumably 1 + 2 + 2 bytes).
const HEADER_SIZE: u16 = 1 + 2 + 2;
/// Payload length limit, 2^14 + 2048 bytes. Only enforced by `read` when the
/// `enable-guards` feature is enabled.
const MAX_PAYLOAD: u16 = 16384 + 2048;
/// Largest encoded record size: the payload limit plus the header.
pub const MAX_WIRE_SIZE: usize = (Self::MAX_PAYLOAD + Self::HEADER_SIZE) as usize;
/// Reads one record (header followed by `len` payload bytes) from `r`.
///
/// Without the `enable-guards` feature the header is parsed leniently. A content type
/// that cannot be read defaults to `ApplicationData`, and a version that cannot be read
/// defaults to `TLSv1_3`. No length, content-type or version checks are applied.
/// With `enable-guards`, all header fields are required and the checks below apply.
///
/// # Errors
/// - `TooShortForHeader`: the length field is missing (with guards, also a missing
///   content type or version).
/// - `TooShortForLength`: fewer than `len` bytes remain after the header.
/// - With guards only:
///   - `IllegalLength` for an empty non-application-data record or `len >= MAX_PAYLOAD`.
///   - `IllegalContentType` for an unknown content type.
///   - `IllegalProtocolVersion` for an unknown version whose high byte is not 0x03.
pub fn read(r: &mut Reader) -> Result<Self, MessageError> {
// Lenient header parsing: fall back to fixed defaults instead of failing.
#[cfg(not(feature = "enable-guards"))]
let typ = ContentType::read(r).unwrap_or(ApplicationData);
#[cfg(not(feature = "enable-guards"))]
let version = ProtocolVersion::read(r).unwrap_or(TLSv1_3);
// Strict header parsing: a missing field is an error.
#[cfg(feature = "enable-guards")]
let typ = ContentType::read(r).ok_or(MessageError::TooShortForHeader)?;
#[cfg(feature = "enable-guards")]
let version = ProtocolVersion::read(r).ok_or(MessageError::TooShortForHeader)?;
// The length field is required in both modes.
let len = u16::read(r).ok_or(MessageError::TooShortForHeader)?;
// Only application data may have an empty body.
#[cfg(feature = "enable-guards")]
if typ != ContentType::ApplicationData && len == 0 {
return Err(MessageError::IllegalLength);
}
// NOTE(review): `>=` also rejects a body of exactly MAX_PAYLOAD bytes; confirm that is intended.
#[cfg(feature = "enable-guards")]
if len >= Self::MAX_PAYLOAD {
return Err(MessageError::IllegalLength);
}
#[cfg(feature = "enable-guards")]
if let ContentType::Unknown(_) = typ {
return Err(MessageError::IllegalContentType);
}
// Known versions are always accepted; unknown ones must still have 0x03 as the high byte.
#[cfg(feature = "enable-guards")]
match version {
ProtocolVersion::Unknown(ref v) if (v & 0xff00) != 0x0300 => {
return Err(MessageError::IllegalProtocolVersion);
}
_ => {}
};
// Restrict the payload read to exactly `len` bytes.
let mut sub = r.sub(len as usize).ok_or(MessageError::TooShortForLength)?;
let payload = Payload::read(&mut sub);
Ok(Self {
typ,
version,
payload,
})
}
/// Converts into a [`PlainMessage`], moving all three fields across unchanged.
pub fn into_plain_message(self) -> PlainMessage {
PlainMessage {
version: self.version,
typ: self.typ,
payload: self.payload,
}
}
}
impl From<Message> for PlainMessage {
fn from(msg: Message) -> Self {
let typ = msg.payload.content_type();
let payload = match msg.payload {
MessagePayload::ApplicationData(payload) => payload,
_ => {
let mut buf = Vec::new();
msg.payload.encode(&mut buf);
Payload(buf)
}
};
Self {
typ,
version: msg.version,
payload,
}
}
}
/// A record whose payload is held as undecoded bytes. In this file it is converted to
/// and from [`OpaqueMessage`] by moving the fields unchanged.
#[derive(Debug, Clone)]
pub struct PlainMessage {
    /// Record content type.
    pub typ: ContentType,
    /// Record protocol version.
    pub version: ProtocolVersion,
    /// Undecoded payload bytes.
    pub payload: Payload,
}
impl PlainMessage {
    /// Repackages the fields as an [`OpaqueMessage`] without transforming the payload.
    pub fn into_unencrypted_opaque(self) -> OpaqueMessage {
        let Self {
            typ,
            version,
            payload,
        } = self;
        OpaqueMessage {
            typ,
            version,
            payload,
        }
    }

    /// Returns a borrowed view whose payload slice points into `self`.
    pub fn borrow(&self) -> BorrowedPlainMessage<'_> {
        BorrowedPlainMessage {
            typ: self.typ,
            version: self.version,
            payload: &self.payload.0[..],
        }
    }
}
/// A TLS message whose payload has been decoded into a [`MessagePayload`].
#[derive(Clone, Debug)]
pub struct Message {
    /// Protocol version carried with the message.
    pub version: ProtocolVersion,
    /// Decoded payload.
    pub payload: MessagePayload,
}

// Opt `Message` into puffin's `VecCodecWoSize`; nothing is overridden here.
impl VecCodecWoSize for Message {}
impl Codec for Message {
fn encode(&self, bytes: &mut Vec<u8>) {
Codec::encode(&self.create_opaque(), bytes);
}
fn read(reader: &mut Reader) -> Option<Self> {
<OpaqueMessage>::read(reader)
.ok()
.and_then(|op| Message::try_from(op).ok())
}
}
impl Message {
    /// True iff the payload is a parsed handshake message of type `hstyp`.
    pub fn is_handshake_type(&self, hstyp: HandshakeType) -> bool {
        match &self.payload {
            MessagePayload::Handshake(hsp) => hsp.typ == hstyp,
            _ => false,
        }
    }

    /// Builds an alert message with the given level and description, tagged `TLSv1_2`.
    pub fn build_alert(level: AlertLevel, desc: AlertDescription) -> Self {
        let alert = AlertMessagePayload {
            level,
            description: desc,
        };
        Self {
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::Alert(alert),
        }
    }

    /// Wraps `HandshakeMessagePayload::build_key_update_notify()` in a `TLSv1_3` message.
    pub fn build_key_update_notify() -> Self {
        let handshake = HandshakeMessagePayload::build_key_update_notify();
        Self {
            version: ProtocolVersion::TLSv1_3,
            payload: MessagePayload::Handshake(handshake),
        }
    }
}
impl TryFrom<PlainMessage> for Message {
type Error = Error;
fn try_from(plain: PlainMessage) -> Result<Self, Self::Error> {
Ok(Self {
version: plain.version,
payload: MessagePayload::new(plain.typ, plain.version, plain.payload)?,
})
}
}
impl TryFrom<OpaqueMessage> for Message {
    type Error = Error;

    /// Converts via [`OpaqueMessage::into_plain_message`] and then `TryFrom<PlainMessage>`.
    fn try_from(value: OpaqueMessage) -> Result<Self, Self::Error> {
        let plain = value.into_plain_message();
        Message::try_from(plain)
    }
}
/// Borrowed view of a [`PlainMessage`], as produced by `PlainMessage::borrow`.
pub struct BorrowedPlainMessage<'a> {
    /// Record protocol version.
    pub version: ProtocolVersion,
    /// Record content type.
    pub typ: ContentType,
    /// Payload bytes borrowed from the owning message.
    pub payload: &'a [u8],
}
/// Reasons `OpaqueMessage::read` can reject a record.
#[derive(Debug)]
pub enum MessageError {
/// A header field needed for parsing could not be read.
TooShortForHeader,
/// Fewer bytes remain than the header's length field declares.
TooShortForLength,
/// The declared length is not allowed (checked only with `enable-guards`).
IllegalLength,
/// The content type is unknown (checked only with `enable-guards`).
IllegalContentType,
/// The protocol version is unknown and its high byte is not 0x03 (checked only with `enable-guards`).
IllegalProtocolVersion,
}
// Opt these element types into puffin's `VecCodecWoSize`. The impls are empty, so only
// the trait's provided items are used.
impl VecCodecWoSize for ClientExtension {}
impl VecCodecWoSize for ServerExtension {}
impl VecCodecWoSize for HelloRetryExtension {}
impl VecCodecWoSize for CertReqExtension {}
impl VecCodecWoSize for CertificateExtension {}
impl VecCodecWoSize for NewSessionTicketExtension {}
impl VecCodecWoSize for Certificate {}
impl VecCodecWoSize for CertificateEntry {}
impl VecCodecWoSize for Compression {}
impl VecCodecWoSize for CipherSuite {}
impl VecCodecWoSize for PresharedKeyIdentity {}

// Applies to the `try_read!` macro that follows: export it at the crate root.
#[macro_export]
/// Decodes `$bitstring` with the first listed type whose `TypeId` equals `$ti`, boxing the
/// result as `Box<dyn EvaluatedTerm<TLSProtocolTypes>>`.
///
/// Expands to a `Result` whose error is built as `Term(..).into()`. Paths such as `TypeId`,
/// `Term`, `EvaluatedTerm` and `TLSProtocolTypes` are resolved at the invocation site, so
/// they must be in scope there. Each listed type must provide `read_bytes`.
///
/// Error messages are only formatted on the failure path; the success path no longer pays
/// for a Debug dump of the whole bitstring.
macro_rules! try_read {
    ($bitstring:expr, $ti:expr, $T:ty, $($Ts:ty),+) => {
        {
            if $ti == TypeId::of::<$T>() {
                log::trace!("Type match TypeID {:?}...!", core::any::type_name::<$T>());
                <$T>::read_bytes($bitstring)
                    .ok_or_else(|| {
                        Term(format!(
                            "[try_read_bytes] Failed to read to type {:?} the bitstring {:?}",
                            core::any::type_name::<$T>(),
                            &$bitstring
                        ))
                        .into()
                    })
                    .map(|v| Box::new(v) as Box<dyn EvaluatedTerm<TLSProtocolTypes>>)
            } else {
                try_read!($bitstring, $ti, $($Ts),+)
            }
        }
    };
    ($bitstring:expr, $ti:expr, $T:ty ) => {
        {
            if $ti == TypeId::of::<$T>() {
                log::trace!("Type match TypeID {:?}...!", core::any::type_name::<$T>());
                <$T>::read_bytes($bitstring)
                    .ok_or_else(|| {
                        Term(format!(
                            "[try_read_bytes] Failed to read to type {:?} the bitstring {:?}",
                            core::any::type_name::<$T>(),
                            &$bitstring
                        ))
                        .into()
                    })
                    .map(|v| Box::new(v) as Box<dyn EvaluatedTerm<TLSProtocolTypes>>)
            } else {
                // No listed type matched `$ti`.
                log::error!("Failed to find a suitable type with typeID {:?} to read the bitstring {:?}", $ti, &$bitstring);
                Err(Term(format!(
                    "[try_read_bytes] Failed to find a suitable type with typeID {:?} to read the bitstring {:?}",
                    $ti,
                    &$bitstring
                )).into())
            }
        }
    };
}
/// Decodes `bitstring` as the concrete type identified by `ty` and returns it as a boxed
/// TLS evaluated term.
///
/// The supported types are listed below, grouped by category. The types have distinct
/// `TypeId`s, so at most one entry matches and the order of the list does not affect
/// the result.
///
/// # Errors
/// Returns a `Term` error when `ty` is not one of the listed types, or when the bytes
/// cannot be read as the matching type.
pub fn try_read_bytes(
    bitstring: &[u8],
    ty: TypeId,
) -> Result<Box<dyn EvaluatedTerm<TLSProtocolTypes>>, puffin::error::Error> {
    log::trace!("Trying read...");
    try_read!(
        bitstring,
        ty,
        // Whole records and flights
        Message,
        OpaqueMessage,
        MessageFlight,
        OpaqueMessageFlight,
        // Certificates
        Certificate,
        Vec<Certificate>,
        CertificateEntry,
        Vec<CertificateEntry>,
        CertificateEntries,
        // Extensions
        ClientExtension,
        Vec<ClientExtension>,
        ClientExtensions,
        ServerExtension,
        Vec<ServerExtension>,
        ServerExtensions,
        HelloRetryExtension,
        Vec<HelloRetryExtension>,
        HelloRetryExtensions,
        CertReqExtension,
        Vec<CertReqExtension>,
        CertificateExtension,
        Vec<CertificateExtension>,
        NewSessionTicketExtension,
        Vec<NewSessionTicketExtension>,
        NewSessionTicketExtensions,
        // Hello fields
        Random,
        SessionID,
        CipherSuite,
        Vec<CipherSuite>,
        CipherSuites,
        Compression,
        Vec<Compression>,
        Compressions,
        PresharedKeyIdentity,
        Vec<PresharedKeyIdentity>,
        // Enumerations and miscellaneous protocol values
        AlertMessagePayload,
        SignatureScheme,
        ProtocolVersion,
        NamedGroup,
        HandshakeHash,
        // Length-prefixed byte payloads
        PayloadU8,
        PayloadU16,
        PayloadU24,
        Vec<PayloadU8>,
        Vec<PayloadU16>,
        Vec<PayloadU24>,
        VecU16OfPayloadU8,
        VecU16OfPayloadU16,
        // Primitives and raw byte containers
        bool,
        u32,
        u64,
        Vec<u8>,
        Vec<Vec<u8>>,
        Option<Vec<u8>>
    )
}