use tokio::io::{AsyncRead, AsyncReadExt};
/// Errors produced when decoded bytes do not form a valid message.
#[derive(thiserror::Error, Debug)]
pub enum DecodeError {
    /// The version byte did not match the expected protocol version.
    ///
    /// Fields are `(received, expected)`.
    #[error("invalid version {0} (expected {1})")]
    InvalidVersion(u8, u8),
    /// The message type byte does not correspond to any known message type.
    #[error("invalid type {0}")]
    InvalidType(u8),
    /// A byte that should encode a boolean was neither `0` nor `1`.
    #[error("invalid boolean {0}")]
    InvalidBoolean(u8),
}
/// Protocol version written as the first byte of every encoded message.
///
/// Decoding rejects any message whose version byte differs from this value.
const PROTOCOL_VERSION: u8 = 0;
/// A protocol message exchanged between master and worker.
///
/// Byte representation:
///
/// | u8      | u8   | [u8]                                     |
/// |---------|------|------------------------------------------|
/// | version | type | content bytes, structure varies per type |
///
/// The type byte is `0` for [`Message::Authentication`] and `1` for
/// [`Message::TargetStateChange`].
#[derive(Debug, PartialEq)]
pub enum Message {
    /// Master -> Worker authentication.
    ///
    /// Content bytes: the key length as a big-endian `u32`, followed by the
    /// key's UTF-8 bytes.
    Authentication { key: String },
    // Master -> Worker target configuration (variant currently commented out).
    // Kept as a plain comment so it is not rendered as documentation of the
    // next variant.
    // ConfigureTarget { target_id: u16, method: u8, addr: String },
    /// Worker -> Master target state change.
    ///
    /// Content bytes: the target ID as a big-endian `u16`, followed by one
    /// state byte (`1` for up, `0` for down).
    TargetStateChange {
        /// ID of the target.
        target_id: u16,
        /// State the target changed to.
        ///
        /// [`true`] means up.
        /// [`false`] means down.
        state: bool,
    },
}
impl From<&Message> for u8 {
    /// Maps a message to the type tag that identifies its variant on the wire.
    ///
    /// `Authentication` maps to `0` and `TargetStateChange` maps to `1`.
    fn from(message: &Message) -> u8 {
        match *message {
            Message::TargetStateChange { .. } => 1,
            Message::Authentication { .. } => 0,
        }
    }
}
impl Message {
pub fn encode(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let mut buf: Vec<u8> = Vec::new();
buf.push(PROTOCOL_VERSION);
// Message type
buf.push(self.into());
match self {
Message::Authentication { key } => {
// Key length
buf.extend((key.len() as u32).to_be_bytes());
// Key
buf.extend(key.as_bytes());
}
Message::TargetStateChange { target_id, state } => {
buf.extend(target_id.to_be_bytes());
buf.push((*state).into());
}
}
Ok(buf)
}
pub async fn decode(
buf: &mut (impl AsyncRead + std::marker::Unpin),
) -> Result<Message, Box<dyn std::error::Error + Send + Sync>> {
let version = buf.read_u8().await?;
if version != PROTOCOL_VERSION {
return Err(Box::new(DecodeError::InvalidVersion(
version,
PROTOCOL_VERSION,
)));
}
let mtype = buf.read_u8().await?;
let message;
match mtype {
0 => {
let key_length = buf.read_u32().await?;
let mut key_bytes = vec![0; key_length as usize];
buf.read_exact(&mut key_bytes).await?;
message = Message::Authentication {
key: String::from_utf8(key_bytes)?,
};
}
1 => {
let target_id = buf.read_u16().await?;
let state = u8_to_bool(buf.read_u8().await?)?;
message = Message::TargetStateChange { target_id, state };
}
_ => return Err(Box::new(DecodeError::InvalidType(mtype))),
}
Ok(message)
}
}
/// Converts a wire byte into a boolean.
///
/// `0` decodes to `false` and `1` decodes to `true`; any other value yields
/// [`DecodeError::InvalidBoolean`] carrying the offending byte.
fn u8_to_bool(n: u8) -> Result<bool, DecodeError> {
    if n > 1 {
        return Err(DecodeError::InvalidBoolean(n));
    }
    Ok(n == 1)
}
/// Encoding an `Authentication` message and decoding the resulting bytes must
/// yield a message equal to the original.
#[tokio::test]
async fn test_round_trip_authentication() {
    let original = Message::Authentication {
        key: String::from("this_is_a_key"),
    };
    let encoded = original.encode().unwrap();
    let decoded = Message::decode(&mut std::io::Cursor::new(encoded))
        .await
        .unwrap();
    assert_eq!(original, decoded);
}
/// Encoding a `TargetStateChange` message and decoding the resulting bytes
/// must yield a message equal to the original.
#[tokio::test]
async fn test_round_trip_targetstatechange() {
    let original = Message::TargetStateChange {
        target_id: 42,
        state: true,
    };
    let encoded = original.encode().unwrap();
    let decoded = Message::decode(&mut std::io::Cursor::new(encoded))
        .await
        .unwrap();
    assert_eq!(original, decoded);
}