mod header;
use header::{MessageHeader, HEADER_LENGTH};
mod payload;
use payload::MessagePayload;
/// Version number of the message protocol.
///
/// NOTE(review): presumably the value compared against the version byte of a
/// message header (cf. `DecodeError::InvalidVersion`) — confirm in the
/// `header` module, since this file does not reference the constant itself.
pub const PROTOCOL_VERSION: u8 = 0;
#[derive(thiserror::Error, Debug)]
pub enum DecodeError {
#[error("invalid version {0} (expected {1})")]
InvalidVersion(u8, u8),
#[error("invalid paylaod type {0}")]
InvalidPayloadType(u8),
#[error("invalid boolean {0}")]
InvalidBoolean(u8),
}
/// A protocol message wrapping a single payload.
#[derive(Debug, PartialEq)]
pub struct Message {
// No header is stored: `encode` builds one from the payload, and `decode`
// discards the header once the payload has been read.
/// The content carried by this message.
payload: MessagePayload,
}
impl Message {
    /// Creates a message carrying `payload`.
    pub fn new(payload: MessagePayload) -> Message {
        Message { payload }
    }

    /// Serializes the message as the encoded header followed by the encoded
    /// payload.
    ///
    /// The header is built from the payload's type and the byte length of the
    /// encoded payload.
    ///
    /// # Panics
    ///
    /// Panics if the encoded payload is longer than `u32::MAX` bytes: the
    /// header length is passed as a `u32`, and truncating it would silently
    /// produce a frame that cannot be decoded.
    pub fn encode(&self) -> Vec<u8> {
        let payload_type = (&self.payload).into();
        let payload_bytes = self.payload.encode();
        // Checked conversion instead of `as u32`, which would wrap on overflow.
        // Fully qualified so the block does not depend on the edition prelude.
        let payload_length = <u32 as std::convert::TryFrom<usize>>::try_from(payload_bytes.len())
            .expect("encoded payload must not exceed u32::MAX bytes");
        let header = MessageHeader::new(payload_type, payload_length);
        let header_bytes = header.encode();

        // The final size is known up front, so allocate once.
        let mut buf = Vec::with_capacity(HEADER_LENGTH + payload_bytes.len());
        buf.extend(header_bytes);
        buf.extend(payload_bytes);
        buf
    }

    /// Reads one message from `buf`: a `HEADER_LENGTH`-byte header, then as
    /// many payload bytes as the header declares.
    ///
    /// # Errors
    ///
    /// Returns an error if the reader fails or ends before the full header and
    /// payload have been read (an `std::io::Error`, of kind `UnexpectedEof`
    /// for a short read), or if decoding the header or the payload fails.
    pub fn decode(buf: &mut impl std::io::Read) -> Result<Message, Box<dyn std::error::Error>> {
        use std::io::Read;

        let mut header_bytes = [0; HEADER_LENGTH];
        buf.read_exact(&mut header_bytes)?;
        let header = MessageHeader::decode(header_bytes)?;

        // The declared length comes from the wire and is untrusted. Reading
        // through `take` lets the buffer grow only as bytes actually arrive,
        // instead of allocating the full declared size up front.
        let payload_length = header.payload_length as usize;
        let mut payload_bytes = Vec::new();
        buf.by_ref()
            .take(payload_length as u64)
            .read_to_end(&mut payload_bytes)?;
        if payload_bytes.len() != payload_length {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "payload is shorter than the length declared in the header",
            )
            .into());
        }

        let mut payload_reader = std::io::Cursor::new(payload_bytes);
        let payload = MessagePayload::decode(header.payload_type, &mut payload_reader)?;
        Ok(Message { payload })
    }
}
/// An authentication message must survive an encode/decode round trip unchanged.
#[test]
fn test_round_trip_message_authentication() {
    let original = Message::new(MessagePayload::Authentication {
        key: String::from("this_is_a_key"),
    });

    // Feed the encoded bytes straight back through the decoder.
    let mut wire = std::io::Cursor::new(original.encode());
    let decoded = Message::decode(&mut wire).unwrap();

    assert_eq!(original, decoded);
}