From ef2b830165b22e9bcc235221875b81e359a8bb7d Mon Sep 17 00:00:00 2001 From: Jonni Liljamo Date: Fri, 8 Nov 2024 13:50:49 +0200 Subject: [PATCH] feat: 2nd protocol iteration --- Cargo.lock | 14 +++ Cargo.toml | 7 +- emerwen-master/Cargo.toml | 1 + emerwen-protocol/Cargo.toml | 3 + emerwen-protocol/src/header.rs | 68 ++++++++++++++ emerwen-protocol/src/lib.rs | 152 ++++++++------------------------ emerwen-protocol/src/payload.rs | 124 ++++++++++++++++++++++++++ emerwen-types/Cargo.toml | 11 +++ emerwen-types/src/lib.rs | 5 ++ emerwen-worker/Cargo.toml | 1 + 10 files changed, 272 insertions(+), 114 deletions(-) create mode 100644 emerwen-protocol/src/header.rs create mode 100644 emerwen-protocol/src/payload.rs create mode 100644 emerwen-types/Cargo.toml create mode 100644 emerwen-types/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index ae607f4..2fe9e45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -102,6 +102,12 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.8.0" @@ -166,6 +172,7 @@ version = "0.1.0" dependencies = [ "clap", "emerwen-protocol", + "emerwen-types", "tokio", "tracing", "tracing-subscriber", @@ -175,16 +182,23 @@ dependencies = [ name = "emerwen-protocol" version = "0.1.0" dependencies = [ + "byteorder", + "emerwen-types", "thiserror", "tokio", ] +[[package]] +name = "emerwen-types" +version = "0.1.0" + [[package]] name = "emerwen-worker" version = "0.1.0" dependencies = [ "clap", "emerwen-protocol", + "emerwen-types", "tokio", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index a700736..ea0c9fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,11 @@ [workspace] resolver = "2" -members = ["emerwen-master", "emerwen-protocol", "emerwen-worker"] +members = [ + "emerwen-master", + "emerwen-protocol", + "emerwen-types", + "emerwen-worker", +] [workspace.package] authors = ["Jonni Liljamo MessageHeader { + MessageHeader { + protocol_version: PROTOCOL_VERSION, + payload_type, + payload_length, + } + } + + pub fn encode(&self) -> [u8; HEADER_LENGTH] { + let length_bytes = self.payload_length.to_be_bytes(); + [ + self.protocol_version, + self.payload_type, + length_bytes[0], + length_bytes[1], + length_bytes[2], + length_bytes[3], + ] + } + + pub fn decode(bytes: [u8; HEADER_LENGTH]) -> Result> { + let protocol_version = bytes[0]; + if protocol_version != PROTOCOL_VERSION { + return Err(Box::new(DecodeError::InvalidVersion( + protocol_version, + PROTOCOL_VERSION, + ))); + } + + let payload_type = bytes[1]; + // TODO: check? + + let payload_length = u32::from_be_bytes([bytes[2], bytes[3], bytes[4], bytes[5]]); + + Ok(MessageHeader { + protocol_version, + payload_type, + payload_length, + }) + } +} + +#[test] +fn test_round_trip_header() { + let sent_header = MessageHeader::new(6, 255); + let sent_header_bytes = sent_header.encode(); + + let received_header = MessageHeader::decode(sent_header_bytes).unwrap(); + + assert_eq!(sent_header, received_header); +} diff --git a/emerwen-protocol/src/lib.rs b/emerwen-protocol/src/lib.rs index 9d34c1f..c91b15a 100644 --- a/emerwen-protocol/src/lib.rs +++ b/emerwen-protocol/src/lib.rs @@ -1,143 +1,69 @@ -use tokio::io::{AsyncRead, AsyncReadExt}; +mod header; +use header::{MessageHeader, HEADER_LENGTH}; + +mod payload; +use payload::MessagePayload; + +pub const PROTOCOL_VERSION: u8 = 0; #[derive(thiserror::Error, Debug)] pub enum DecodeError { #[error("invalid version {0} (expected {1})")] InvalidVersion(u8, u8), - #[error("invalid type {0}")] - InvalidType(u8), + #[error("invalid paylaod type {0}")] + InvalidPayloadType(u8), #[error("invalid boolean {0}")] InvalidBoolean(u8), } -const PROTOCOL_VERSION: u8 = 0; - -/// Message -/// -/// Byte representation: -/// | u8 | u8 | [u8] -/// | version | type | content bytes, structure varies per type #[derive(Debug, PartialEq)] -pub enum Message { - /// Master -> Worker authentication. - Authentication { key: String }, - /// Master -> Worker target configuration. - //ConfigureTarget { target_id: u16, method: u8, addr: String }, - /// Worker -> Master target state change. - TargetStateChange { - /// ID of the target. - target_id: u16, - /// State the target changed to. - /// - /// [`true`] means up. - /// [`false`] means down. - state: bool, - }, -} - -impl From<&Message> for u8 { - fn from(message: &Message) -> u8 { - match message { - Message::Authentication { .. } => 0, - Message::TargetStateChange { .. } => 1, - } - } +pub struct Message { + //header: MessageHeader, + payload: MessagePayload, } impl Message { - pub fn encode(&self) -> Result, Box> { - let mut buf: Vec = Vec::new(); - - buf.push(PROTOCOL_VERSION); - - // Message type - buf.push(self.into()); - - match self { - Message::Authentication { key } => { - // Key length - buf.extend((key.len() as u32).to_be_bytes()); - // Key - buf.extend(key.as_bytes()); - } - Message::TargetStateChange { target_id, state } => { - buf.extend(target_id.to_be_bytes()); - buf.push((*state).into()); - } - } - - Ok(buf) + pub fn new(payload: MessagePayload) -> Message { + Message { payload } } - pub async fn decode( - buf: &mut (impl AsyncRead + std::marker::Unpin), - ) -> Result> { - let version = buf.read_u8().await?; - if version != PROTOCOL_VERSION { - return Err(Box::new(DecodeError::InvalidVersion( - version, - PROTOCOL_VERSION, - ))); - } + pub fn encode(&self) -> Vec { + let payload_type = (&self.payload).into(); + let payload_bytes = self.payload.encode(); - let mtype = buf.read_u8().await?; + let header = MessageHeader::new(payload_type, payload_bytes.len() as u32); + let header_bytes = header.encode(); - let message; - match mtype { - 0 => { - let key_length = buf.read_u32().await?; - let mut key_bytes = vec![0; key_length as usize]; - buf.read_exact(&mut key_bytes).await?; - message = Message::Authentication { - key: String::from_utf8(key_bytes)?, - }; - } - 1 => { - let target_id = buf.read_u16().await?; - let state = u8_to_bool(buf.read_u8().await?)?; + let mut buf = vec![]; + buf.extend(header_bytes); + buf.extend(payload_bytes); + buf + } - message = Message::TargetStateChange { target_id, state }; - } - _ => return Err(Box::new(DecodeError::InvalidType(mtype))), - } + pub fn decode(buf: &mut impl std::io::Read) -> Result> { + let mut header_bytes = [0; HEADER_LENGTH]; + buf.read_exact(&mut header_bytes)?; + let header = MessageHeader::decode(header_bytes)?; - Ok(message) - } -} + let mut payload_bytes = vec![0; header.payload_length as usize]; + buf.read_exact(&mut payload_bytes)?; + let mut payload_reader = std::io::Cursor::new(payload_bytes); + let payload = MessagePayload::decode(header.payload_type, &mut payload_reader)?; -fn u8_to_bool(n: u8) -> Result { - match n { - 0 => Ok(false), - 1 => Ok(true), - _ => Err(DecodeError::InvalidBoolean(n)), + Ok(Message { payload }) } } -#[tokio::test] -async fn test_round_trip_authentication() { - let sent_message = Message::Authentication { +#[test] +fn test_round_trip_message_authentication() { + let sent_message = Message::new(MessagePayload::Authentication { key: "this_is_a_key".to_owned(), - }; - - let sent_message_bytes = sent_message.encode().unwrap(); - - let mut reader = std::io::Cursor::new(sent_message_bytes); - let received_message = Message::decode(&mut reader).await.unwrap(); - - assert_eq!(sent_message, received_message); -} - -#[tokio::test] -async fn test_round_trip_targetstatechange() { - let sent_message = Message::TargetStateChange { - target_id: 42, - state: true, - }; + }); - let sent_message_bytes = sent_message.encode().unwrap(); + let sent_message_bytes = sent_message.encode(); let mut reader = std::io::Cursor::new(sent_message_bytes); - let received_message = Message::decode(&mut reader).await.unwrap(); + let received_message = Message::decode(&mut reader).unwrap(); assert_eq!(sent_message, received_message); } diff --git a/emerwen-protocol/src/payload.rs b/emerwen-protocol/src/payload.rs new file mode 100644 index 0000000..376b214 --- /dev/null +++ b/emerwen-protocol/src/payload.rs @@ -0,0 +1,124 @@ +use byteorder::{BigEndian, ReadBytesExt}; + +use crate::DecodeError; + +/// Message Payload +/// +/// Byte representation: +/// | [u8] | +/// | content bytes, structure varies per type | +#[derive(Debug, PartialEq)] +pub enum MessagePayload { + /// Master -> Worker authentication. + Authentication { key: String }, + /// Master -> Worker target configuration. + /*ConfigureTarget { + target_id: u16, + method: TargetMethod, + addr: String + },*/ + /// Worker -> Master target state change. + TargetStateChange { + /// ID of the target. + target_id: u16, + /// State the target changed to. + /// + /// [`true`] means up. + /// [`false`] means down. + state: bool, + }, +} + +impl From<&MessagePayload> for u8 { + fn from(message: &MessagePayload) -> u8 { + match message { + MessagePayload::Authentication { .. } => 0, + MessagePayload::TargetStateChange { .. } => 1, + } + } +} + +impl MessagePayload { + pub fn encode(&self) -> Vec { + let mut buf: Vec = Vec::new(); + + match self { + MessagePayload::Authentication { key } => { + // Key length + buf.extend((key.len() as u32).to_be_bytes()); + // Key + buf.extend(key.as_bytes()); + } + MessagePayload::TargetStateChange { target_id, state } => { + buf.extend(target_id.to_be_bytes()); + buf.push((*state).into()); + } + } + + buf + } + + pub fn decode( + payload_type: u8, + buf: &mut impl std::io::Read, + ) -> Result> { + let payload = match payload_type { + 0 => { + let key_length = buf.read_u32::()?; + let mut key_bytes = vec![0; key_length as usize]; + buf.read_exact(&mut key_bytes)?; + MessagePayload::Authentication { + key: String::from_utf8(key_bytes)?, + } + } + 1 => { + let target_id = buf.read_u16::()?; + let state = u8_to_bool(buf.read_u8()?)?; + + MessagePayload::TargetStateChange { target_id, state } + } + _ => return Err(Box::new(DecodeError::InvalidPayloadType(payload_type))), + }; + + Ok(payload) + } +} + +fn u8_to_bool(n: u8) -> Result { + match n { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(DecodeError::InvalidBoolean(n)), + } +} + +#[test] +fn test_round_trip_payload_authentication() { + let sent_payload = MessagePayload::Authentication { + key: "this_is_a_key".to_owned(), + }; + + let sent_payload_bytes = sent_payload.encode(); + let sent_payload_type = (&sent_payload).into(); + + let mut reader = std::io::Cursor::new(sent_payload_bytes); + let received_payload = MessagePayload::decode(sent_payload_type, &mut reader).unwrap(); + + assert_eq!(sent_payload, received_payload); +} + +#[test] +fn test_round_trip_payload_targetstatechange() { + let sent_payload = MessagePayload::TargetStateChange { + target_id: 42, + state: true, + }; + + let sent_payload_bytes = sent_payload.encode(); + let sent_payload_type = (&sent_payload).into(); + + let mut reader = std::io::Cursor::new(sent_payload_bytes); + let received_payload = MessagePayload::decode(sent_payload_type, &mut reader).unwrap(); + + assert_eq!(sent_payload, received_payload); +} diff --git a/emerwen-types/Cargo.toml b/emerwen-types/Cargo.toml new file mode 100644 index 0000000..4e10a9f --- /dev/null +++ b/emerwen-types/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "emerwen-types" +version = "0.1.0" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +publish.workspace = true +repository.workspace = true + +[dependencies] diff --git a/emerwen-types/src/lib.rs b/emerwen-types/src/lib.rs new file mode 100644 index 0000000..18c8435 --- /dev/null +++ b/emerwen-types/src/lib.rs @@ -0,0 +1,5 @@ +#[derive(Debug, PartialEq)] +pub enum TargetMethod { + Ping, + Get(u8), +} diff --git a/emerwen-worker/Cargo.toml b/emerwen-worker/Cargo.toml index 5a3c126..6caa137 100644 --- a/emerwen-worker/Cargo.toml +++ b/emerwen-worker/Cargo.toml @@ -10,6 +10,7 @@ repository.workspace = true [dependencies] emerwen-protocol = { path = "../emerwen-protocol" } +emerwen-types = { path = "../emerwen-types" } clap = { version = "4", features = ["derive"] } tokio = { version = "1", features = ["full"] } -- 2.44.1