use super::{DecodeError, PROTOCOL_VERSION};
/// Size in bytes of an encoded [`MessageHeader`]: one protocol-version byte,
/// one payload-type byte and four payload-length bytes.
pub const HEADER_LENGTH: usize = 6;
/// Message header that precedes every payload on the wire.
///
/// Byte representation (the payload length is stored big-endian):
///
/// | bytes | type            | field              |
/// |-------|-----------------|--------------------|
/// | 0     | `u8`            | `protocol_version` |
/// | 1     | `u8`            | `payload_type`     |
/// | 2..=5 | `[u8; 4]`/`u32` | `payload_length`   |
///
/// The header is a small plain-data value, so it is `Copy` and can be passed
/// around and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MessageHeader {
    /// Protocol version carried in the first header byte. Kept private so it
    /// can only be set through the type's own constructors.
    protocol_version: u8,
    /// Tag identifying the kind of payload that follows the header.
    pub payload_type: u8,
    /// Length of the payload that follows the header.
    pub payload_length: u32,
}
impl MessageHeader {
    /// Builds a header for a payload of the given type and length.
    ///
    /// The version field is always set to [`PROTOCOL_VERSION`].
    pub fn new(payload_type: u8, payload_length: u32) -> MessageHeader {
        Self {
            payload_length,
            payload_type,
            protocol_version: PROTOCOL_VERSION,
        }
    }

    /// Serializes the header into its fixed-size wire form.
    ///
    /// Layout: version byte, payload-type byte, then the payload length as
    /// four big-endian bytes.
    pub fn encode(&self) -> [u8; HEADER_LENGTH] {
        let [len0, len1, len2, len3] = self.payload_length.to_be_bytes();
        [
            self.protocol_version,
            self.payload_type,
            len0,
            len1,
            len2,
            len3,
        ]
    }

    /// Parses a header from its fixed-size wire form.
    ///
    /// # Errors
    ///
    /// Returns a boxed [`DecodeError::InvalidVersion`] holding the received
    /// version followed by [`PROTOCOL_VERSION`] when the first byte does not
    /// equal [`PROTOCOL_VERSION`].
    pub fn decode(bytes: [u8; HEADER_LENGTH]) -> Result<MessageHeader, Box<dyn std::error::Error>> {
        let [version, kind, len0, len1, len2, len3] = bytes;

        // Reject anything that was not produced by the same protocol version.
        if version != PROTOCOL_VERSION {
            let mismatch = DecodeError::InvalidVersion(version, PROTOCOL_VERSION);
            return Err(Box::new(mismatch));
        }

        // TODO: check? -- the payload type and length are accepted as-is;
        // no validation is performed on either field here.
        Ok(Self {
            protocol_version: version,
            payload_type: kind,
            payload_length: u32::from_be_bytes([len0, len1, len2, len3]),
        })
    }
}
// Encoding a header and decoding the resulting bytes must yield an equal header.
#[test]
fn test_round_trip_header() {
    let original = MessageHeader::new(6, 255);
    let wire = original.encode();
    let decoded = MessageHeader::decode(wire).unwrap();
    assert_eq!(original, decoded);
}