use std::error::Error as StdError;
use std::fmt;
use std::time::SystemTimeError;
use crate::tls::rustls::msgs::enums::{AlertDescription, ContentType, HandshakeType};
use crate::tls::rustls::rand;
/// Details attached to [`Error::CorruptMessagePayload`]: which kind of message
/// was found to be corrupt, and where in the source the error was raised.
///
/// Note: the derived `PartialEq` compares `location` as well as `kind`, so two
/// payloads with the same content type created at different call sites are
/// not equal.
#[derive(Debug, PartialEq, Clone)]
pub struct CorruptMessagePayload {
/// Source location that produced this error. `Error::corrupt_message` fills it
/// with its caller's location via `#[track_caller]`.
location: &'static core::panic::Location<'static>,
/// Content type of the message that was found to be corrupt.
kind: ContentType,
}
impl CorruptMessagePayload {
    /// Returns the content type of the message that was found to be corrupt.
    pub fn content_type(&self) -> ContentType {
        let Self { kind, .. } = *self;
        kind
    }
}
/// Errors produced by this TLS implementation.
///
/// Every variant has a short human-readable rendering via `fmt::Display`.
#[derive(Debug, PartialEq, Clone)]
pub enum Error {
/// A message with an unexpected content type was received.
InappropriateMessage {
/// Content types that would have been acceptable at this point.
expect_types: Vec<ContentType>,
/// Content type that was actually received.
got_type: ContentType,
},
/// A handshake message of an unexpected type was received.
InappropriateHandshakeMessage {
/// Handshake types that would have been acceptable at this point.
expect_types: Vec<HandshakeType>,
/// Handshake type that was actually received.
got_type: HandshakeType,
},
/// A corrupt message was received; no content type or location is recorded.
CorruptMessage,
/// A corrupt message was received; the payload records its content type and
/// the source location that raised the error (see `Error::corrupt_message`).
CorruptMessagePayload(CorruptMessagePayload),
/// The peer sent no certificates.
NoCertificatesPresented,
/// The presented server name type was not supported.
UnsupportedNameType,
/// A message from the peer could not be decrypted.
DecryptError,
/// A message could not be encrypted.
EncryptError,
/// The peer is incompatible; the string describes why.
PeerIncompatibleError(String),
/// The peer misbehaved; the string describes how.
PeerMisbehavedError(String),
/// A fatal alert with the given description was received.
AlertReceived(AlertDescription),
/// The peer's certificate encoding was invalid.
InvalidCertificateEncoding,
/// The peer's certificate signature type was invalid.
InvalidCertificateSignatureType,
/// The peer's certificate signature was invalid.
InvalidCertificateSignature,
/// The peer's certificate contents were invalid; the string gives the reason.
InvalidCertificateData(String),
/// A certificate timestamp was invalid; wraps the error from the `sct` crate.
InvalidSct(sct::Error),
/// An error not covered by another variant, described by the string.
General(String),
/// Reading the current time failed. `From<SystemTimeError>` produces this.
FailedToGetCurrentTime,
/// Obtaining random bytes failed. `From<rand::GetRandomFailed>` produces this.
FailedToGetRandomBytes,
/// The handshake is not complete.
HandshakeNotComplete,
/// The peer sent a record that exceeded the permitted size.
PeerSentOversizedRecord,
/// The peer does not support any known protocol.
NoApplicationProtocol,
/// The supplied `max_fragment_size` was too small or too large.
BadMaxFragmentSize,
}
impl Error {
    /// Builds an [`Error::CorruptMessagePayload`] for a corrupt message whose
    /// content type is `kind`.
    ///
    /// Because this function is `#[track_caller]`, the recorded location is the
    /// call site of `corrupt_message`, not a line inside this function.
    #[track_caller]
    pub fn corrupt_message(kind: ContentType) -> Self {
        let location = core::panic::Location::caller();
        let payload = CorruptMessagePayload { location, kind };
        Self::CorruptMessagePayload(payload)
    }
}
/// Renders every item with its `Debug` representation, separating consecutive
/// items with `" or "`. An empty slice yields an empty string.
fn join<T: fmt::Debug>(items: &[T]) -> String {
    let mut rendered = String::new();
    for (position, item) in items.iter().enumerate() {
        if position != 0 {
            rendered.push_str(" or ");
        }
        rendered += &format!("{:?}", item);
    }
    rendered
}
/// Human-readable rendering of each [`Error`] variant.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::InappropriateMessage {
                expect_types,
                got_type,
            } => write!(
                f,
                "received unexpected message: got {:?} when expecting {}",
                got_type,
                join::<ContentType>(expect_types)
            ),
            Self::InappropriateHandshakeMessage {
                expect_types,
                got_type,
            } => write!(
                f,
                "received unexpected handshake message: got {:?} when expecting {}",
                got_type,
                join::<HandshakeType>(expect_types)
            ),
            // Print the content type itself, not the Debug dump of the whole
            // payload struct. The captured location is rendered as
            // `file:line:col` so it stays useful for diagnosis.
            Self::CorruptMessagePayload(payload) => write!(
                f,
                "received corrupt message of type {:?} (detected at {})",
                payload.kind, payload.location
            ),
            Self::PeerIncompatibleError(why) => write!(f, "peer is incompatible: {}", why),
            Self::PeerMisbehavedError(why) => write!(f, "peer misbehaved: {}", why),
            Self::AlertReceived(alert) => write!(f, "received fatal alert: {:?}", alert),
            Self::InvalidCertificateEncoding => {
                write!(f, "invalid peer certificate encoding")
            }
            Self::InvalidCertificateSignatureType => {
                write!(f, "invalid peer certificate signature type")
            }
            Self::InvalidCertificateSignature => {
                write!(f, "invalid peer certificate signature")
            }
            Self::InvalidCertificateData(reason) => {
                write!(f, "invalid peer certificate contents: {}", reason)
            }
            Self::CorruptMessage => write!(f, "received corrupt message"),
            Self::NoCertificatesPresented => write!(f, "peer sent no certificates"),
            Self::UnsupportedNameType => write!(f, "presented server name type wasn't supported"),
            Self::DecryptError => write!(f, "cannot decrypt peer's message"),
            Self::EncryptError => write!(f, "cannot encrypt message"),
            Self::PeerSentOversizedRecord => write!(f, "peer sent excess record size"),
            Self::HandshakeNotComplete => write!(f, "handshake not complete"),
            Self::NoApplicationProtocol => write!(f, "peer doesn't support any known protocol"),
            Self::InvalidSct(err) => write!(f, "invalid certificate timestamp: {:?}", err),
            Self::FailedToGetCurrentTime => write!(f, "failed to get current time"),
            Self::FailedToGetRandomBytes => write!(f, "failed to get random bytes"),
            Self::BadMaxFragmentSize => {
                write!(f, "the supplied max_fragment_size was too small or large")
            }
            Self::General(err) => write!(f, "unexpected error: {}", err),
        }
    }
}
impl From<SystemTimeError> for Error {
#[inline]
fn from(_: SystemTimeError) -> Self {
Self::FailedToGetCurrentTime
}
}
impl StdError for Error {}
impl From<rand::GetRandomFailed> for Error {
fn from(_: rand::GetRandomFailed) -> Self {
Self::FailedToGetRandomBytes
}
}
#[cfg(test)]
mod tests {
    use super::{CorruptMessagePayload, Error};

    /// Builds one value of every variant and checks that each renders a
    /// non-empty `Display` message and compares equal to its own clone.
    #[test_log::test]
    fn smoke() {
        use sct;
        use crate::tls::rustls::msgs::enums::{AlertDescription, ContentType, HandshakeType};
        let all = vec![
            Error::InappropriateMessage {
                expect_types: vec![ContentType::Alert],
                got_type: ContentType::Handshake,
            },
            Error::InappropriateHandshakeMessage {
                expect_types: vec![HandshakeType::ClientHello, HandshakeType::Finished],
                got_type: HandshakeType::ServerHello,
            },
            Error::CorruptMessage,
            Error::CorruptMessagePayload(CorruptMessagePayload {
                location: core::panic::Location::caller(),
                kind: ContentType::Alert,
            }),
            Error::NoCertificatesPresented,
            Error::UnsupportedNameType,
            Error::DecryptError,
            Error::EncryptError,
            Error::PeerIncompatibleError("no tls1.2".to_string()),
            Error::PeerMisbehavedError("inconsistent something".to_string()),
            Error::AlertReceived(AlertDescription::ExportRestriction),
            Error::InvalidCertificateEncoding,
            Error::InvalidCertificateSignatureType,
            Error::InvalidCertificateSignature,
            Error::InvalidCertificateData("Data".into()),
            Error::InvalidSct(sct::Error::MalformedSct),
            Error::General("undocumented error".to_string()),
            Error::FailedToGetCurrentTime,
            Error::FailedToGetRandomBytes,
            Error::HandshakeNotComplete,
            Error::PeerSentOversizedRecord,
            Error::NoApplicationProtocol,
            Error::BadMaxFragmentSize,
        ];
        for err in &all {
            assert!(
                !err.to_string().is_empty(),
                "empty Display output for {:?}",
                err
            );
            assert_eq!(err.clone(), *err);
        }
    }

    /// `Error::corrupt_message` must record the caller's location (via
    /// `#[track_caller]`) together with the given content type.
    #[test_log::test]
    fn corrupt_message_records_caller() {
        use crate::tls::rustls::msgs::enums::ContentType;
        let expected_line = line!() + 1;
        let err = Error::corrupt_message(ContentType::Handshake);
        match &err {
            Error::CorruptMessagePayload(payload) => {
                assert_eq!(payload.content_type(), ContentType::Handshake);
                assert_eq!(payload.location.file(), file!());
                assert_eq!(payload.location.line(), expected_line);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
        assert!(err.to_string().contains("Handshake"));
    }

    #[test_log::test]
    fn rand_error_mapping() {
        use super::rand;
        let err: Error = rand::GetRandomFailed.into();
        assert_eq!(err, Error::FailedToGetRandomBytes);
    }

    #[test_log::test]
    fn time_error_mapping() {
        use std::time::SystemTime;
        // The epoch is always earlier than "now", so this reliably fails.
        let time_error = SystemTime::UNIX_EPOCH
            .duration_since(SystemTime::now())
            .unwrap_err();
        let err: Error = time_error.into();
        assert_eq!(err, Error::FailedToGetCurrentTime);
    }
}