DEVELOPMENT ENVIRONMENT

~liljamo/emerwen

ref: 1d405b318fcc8f3c552c8439782cc1addae5f84c emerwen/emerwen-protocol/src/lib.rs -rw-r--r-- 3.9 KiB
1d405b31 — Jonni Liljamo: feat: initial (15 days ago)
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
use tokio::io::{AsyncRead, AsyncReadExt};

#[derive(thiserror::Error, Debug)]
pub enum DecodeError {
    #[error("invalid version {0} (expected {1})")]
    InvalidVersion(u8, u8),
    #[error("invalid type {0}")]
    InvalidType(u8),
    #[error("invalid boolean {0}")]
    InvalidBoolean(u8),
}

/// Wire-protocol version, written as the first byte of every encoded [`Message`].
/// `Message::decode` rejects any frame whose first byte differs from this value.
const PROTOCOL_VERSION: u8 = 0x00;

/// A protocol message exchanged between the master and its workers.
///
/// Byte representation:
///
/// | Offset | Size     | Field                                       |
/// |--------|----------|---------------------------------------------|
/// | 0      | `u8`     | protocol version                            |
/// | 1      | `u8`     | message type                                |
/// | 2      | variable | content bytes; structure varies per type    |
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// Master -> Worker authentication.
    Authentication {
        /// Authentication key, sent as a big-endian `u32` byte length followed by
        /// the key's UTF-8 bytes.
        key: String,
    },
    // Planned, not yet implemented: Master -> Worker target configuration.
    // (A plain `//` comment here: a `///` doc comment would attach to the next
    // variant and corrupt its rustdoc.)
    //ConfigureTarget { target_id: u16, method: u8, addr: String },
    /// Worker -> Master target state change.
    TargetStateChange {
        /// ID of the target.
        target_id: u16,
        /// State the target changed to.
        ///
        /// [`true`] means up.
        /// [`false`] means down.
        state: bool,
    },
}

impl From<&Message> for u8 {
    /// Returns the wire-format type tag (the second byte of an encoded frame)
    /// for the given message's variant.
    fn from(message: &Message) -> Self {
        // Tag values must stay in sync with the type dispatch in `Message::decode`.
        const AUTHENTICATION: u8 = 0;
        const TARGET_STATE_CHANGE: u8 = 1;

        match *message {
            Message::Authentication { .. } => AUTHENTICATION,
            Message::TargetStateChange { .. } => TARGET_STATE_CHANGE,
        }
    }
}

impl Message {
    pub fn encode(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let mut buf: Vec<u8> = Vec::new();

        buf.push(PROTOCOL_VERSION);

        // Message type
        buf.push(self.into());

        match self {
            Message::Authentication { key } => {
                // Key length
                buf.extend((key.len() as u32).to_be_bytes());
                // Key
                buf.extend(key.as_bytes());
            }
            Message::TargetStateChange { target_id, state } => {
                buf.extend(target_id.to_be_bytes());
                buf.push((*state).into());
            }
        }

        Ok(buf)
    }

    pub async fn decode(
        buf: &mut (impl AsyncRead + std::marker::Unpin),
    ) -> Result<Message, Box<dyn std::error::Error + Send + Sync>> {
        let version = buf.read_u8().await?;
        if version != PROTOCOL_VERSION {
            return Err(Box::new(DecodeError::InvalidVersion(
                version,
                PROTOCOL_VERSION,
            )));
        }

        let mtype = buf.read_u8().await?;

        let message;
        match mtype {
            0 => {
                let key_length = buf.read_u32().await?;
                let mut key_bytes = vec![0; key_length as usize];
                buf.read_exact(&mut key_bytes).await?;
                message = Message::Authentication {
                    key: String::from_utf8(key_bytes)?,
                };
            }
            1 => {
                let target_id = buf.read_u16().await?;
                let state = u8_to_bool(buf.read_u8().await?)?;

                message = Message::TargetStateChange { target_id, state };
            }
            _ => return Err(Box::new(DecodeError::InvalidType(mtype))),
        }

        Ok(message)
    }
}

/// Interprets a wire byte as a boolean: `0` is `false` and `1` is `true`.
///
/// # Errors
///
/// Returns [`DecodeError::InvalidBoolean`] carrying the byte for any other value.
fn u8_to_bool(n: u8) -> Result<bool, DecodeError> {
    if n <= 1 {
        Ok(n != 0)
    } else {
        Err(DecodeError::InvalidBoolean(n))
    }
}

/// An `Authentication` message must survive an `encode` -> `decode` round trip unchanged.
#[tokio::test]
async fn test_round_trip_authentication() {
    let original = Message::Authentication {
        key: String::from("this_is_a_key"),
    };

    // Decode straight from an in-memory reader over the encoded bytes.
    let encoded = original.encode().unwrap();
    let decoded = Message::decode(&mut std::io::Cursor::new(encoded))
        .await
        .unwrap();

    assert_eq!(original, decoded);
}

/// A `TargetStateChange` must survive an `encode` -> `decode` round trip for both
/// states (`u8_to_bool`'s `0` and `1` branches) and for the boundary target IDs.
/// The test also pins the 5-byte frame size and checks that `decode` consumes the
/// whole frame.
#[tokio::test]
async fn test_round_trip_targetstatechange() {
    for target_id in [0, 42, u16::MAX] {
        for state in [false, true] {
            let sent_message = Message::TargetStateChange { target_id, state };

            let sent_message_bytes = sent_message.encode().unwrap();
            // version (1) + type (1) + target_id (2) + state (1)
            assert_eq!(sent_message_bytes.len(), 5);

            let mut reader = std::io::Cursor::new(sent_message_bytes);
            let received_message = Message::decode(&mut reader).await.unwrap();

            assert_eq!(sent_message, received_message);
            assert_eq!(reader.position(), 5);
        }
    }
}