DEVELOPMENT ENVIRONMENT

~liljamo/emerwen

ref: initial-tcp-experimentation emerwen/emerwen-protocol/src/message/header.rs -rw-r--r-- 1.8 KiB
1225af8eJonni Liljamo feat: more experimentation 13 days ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use super::{DecodeError, PROTOCOL_VERSION};

/// Number of bytes in an encoded message header:
/// `protocol_version` (u8) + `payload_type` (u8) + `payload_length` (u32).
pub const HEADER_LENGTH: usize =
    std::mem::size_of::<u8>() + std::mem::size_of::<u8>() + std::mem::size_of::<u32>();

/// Message Header
///
/// Fixed-size header describing a message payload.
///
/// Byte representation (`HEADER_LENGTH` bytes; `payload_length` is big-endian):
/// | u8               | u8           | [u8; 4] / u32  |
/// | protocol_version | payload_type | payload_length |
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MessageHeader {
    /// Wire protocol version. `new` stamps it with `PROTOCOL_VERSION`, and
    /// `decode` rejects any other value, so it is kept private.
    protocol_version: u8,
    /// Payload type discriminator. It is not validated on decode yet.
    pub payload_type: u8,
    /// Length of the payload (presumably in bytes; confirm against the framing code).
    pub payload_length: u32,
}

impl MessageHeader {
    /// Builds a header for a payload of the given type and length, stamped
    /// with the current `PROTOCOL_VERSION`.
    pub fn new(payload_type: u8, payload_length: u32) -> MessageHeader {
        Self {
            payload_type,
            payload_length,
            protocol_version: PROTOCOL_VERSION,
        }
    }

    /// Serializes the header into its fixed-size wire form:
    /// `[protocol_version, payload_type, payload_length as 4 big-endian bytes]`.
    pub fn encode(&self) -> [u8; HEADER_LENGTH] {
        let [len_0, len_1, len_2, len_3] = self.payload_length.to_be_bytes();
        [
            self.protocol_version,
            self.payload_type,
            len_0,
            len_1,
            len_2,
            len_3,
        ]
    }

    /// Parses a header from its fixed-size wire form.
    ///
    /// # Errors
    ///
    /// Returns a boxed `DecodeError::InvalidVersion(received, expected)` when
    /// the first byte is not `PROTOCOL_VERSION`.
    pub fn decode(bytes: [u8; HEADER_LENGTH]) -> Result<MessageHeader, Box<dyn std::error::Error>> {
        let [version, kind, len_0, len_1, len_2, len_3] = bytes;

        if version != PROTOCOL_VERSION {
            return Err(Box::new(DecodeError::InvalidVersion(
                version,
                PROTOCOL_VERSION,
            )));
        }

        // TODO: should `payload_type` be validated here?
        Ok(Self {
            protocol_version: version,
            payload_type: kind,
            payload_length: u32::from_be_bytes([len_0, len_1, len_2, len_3]),
        })
    }
}

/// Encoding a header and decoding the bytes yields an equal header.
#[test]
fn test_round_trip_header() {
    let sent_header = MessageHeader::new(6, 255);
    let sent_header_bytes = sent_header.encode();

    let received_header = MessageHeader::decode(sent_header_bytes).unwrap();

    assert_eq!(sent_header, received_header);
}

/// The encoded bytes follow the documented layout, with `payload_length`
/// written big-endian. A round trip alone cannot catch a consistent
/// byte-order or field-position mistake.
#[test]
fn test_header_wire_layout() {
    let bytes = MessageHeader::new(6, 0x0102_0304).encode();

    assert_eq!(bytes, [PROTOCOL_VERSION, 6, 0x01, 0x02, 0x03, 0x04]);
}

/// Decoding fails when the version byte differs from `PROTOCOL_VERSION`.
#[test]
fn test_decode_rejects_wrong_version() {
    let mut bytes = MessageHeader::new(6, 255).encode();
    // Any value other than PROTOCOL_VERSION; wrapping keeps this valid for every u8.
    bytes[0] = PROTOCOL_VERSION.wrapping_add(1);

    assert!(MessageHeader::decode(bytes).is_err());
}