From f78362008b4b5522181c77e830f585c9d9e8b799 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 14 Mar 2024 22:19:21 +0900 Subject: [PATCH] gAdded a magic number to the message serializing. (#141) The point here is to make sure we ignore message coming from a different program, or coming from a version of chitchat that was not properly versioned. In that case, the enum discriminant could coincide with our version number. --- chitchat/src/message.rs | 55 ++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/chitchat/src/message.rs b/chitchat/src/message.rs index 3dded42..4f05f55 100644 --- a/chitchat/src/message.rs +++ b/chitchat/src/message.rs @@ -6,6 +6,8 @@ use crate::delta::Delta; use crate::digest::Digest; use crate::serialize::{Deserializable, Serializable}; +const MAGIC_NUMBER: u16 = 45_139; + /// Chitchat message. /// /// Each variant represents a step of the gossip "handshake" @@ -72,6 +74,7 @@ impl MessageType { impl Serializable for ChitchatMessage { fn serialize(&self, buf: &mut Vec) { + buf.extend(MAGIC_NUMBER.to_le_bytes()); ProtocolVersion::V0.to_code().serialize(buf); match self { @@ -96,26 +99,31 @@ impl Serializable for ChitchatMessage { } fn serialized_len(&self) -> usize { - 1 + match self { - ChitchatMessage::Syn { cluster_id, digest } => { - 1 + cluster_id.serialized_len() + digest.serialized_len() + 2 + 1 + + match self { + ChitchatMessage::Syn { cluster_id, digest } => { + 1 + cluster_id.serialized_len() + digest.serialized_len() + } + ChitchatMessage::SynAck { digest, delta } => { + 1 + digest.serialized_len() + delta.serialized_len() + } + ChitchatMessage::Ack { delta } => 1 + delta.serialized_len(), + ChitchatMessage::BadCluster => 1, } - ChitchatMessage::SynAck { digest, delta } => { - 1 + digest.serialized_len() + delta.serialized_len() - } - ChitchatMessage::Ack { delta } => 1 + delta.serialized_len(), - ChitchatMessage::BadCluster => 1, - } } } impl Deserializable for ChitchatMessage { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let protocol_version = buf - .first() - .copied() - .and_then(ProtocolVersion::from_code) - .context("invalid protocol version")?; + if buf.len() < 3 { + bail!("buffer too small to store the magic number and the protocol version"); + } + let magic_number = u16::from_le_bytes(buf[0..2].try_into().unwrap()); + if magic_number != MAGIC_NUMBER { + bail!("invalid chitchat magic number"); + } + let protocol_version = + ProtocolVersion::from_code(buf[2]).context("invalid protocol version")?; if protocol_version != ProtocolVersion::V0 { bail!( @@ -123,7 +131,7 @@ impl Deserializable for ChitchatMessage { protocol_version.to_code() ) } - buf.consume(1); + buf.consume(3); let message_type = buf .first() @@ -164,7 +172,7 @@ mod tests { cluster_id: "cluster-a".to_string(), digest: Digest::default(), }; - test_serdeser_aux(&syn, 15); + test_serdeser_aux(&syn, 17); } { let mut digest = Digest::default(); @@ -175,7 +183,7 @@ mod tests { cluster_id: "cluster-a".to_string(), digest, }; - test_serdeser_aux(&syn, 66); + test_serdeser_aux(&syn, 68); } } @@ -186,8 +194,9 @@ mod tests { digest: Digest::default(), delta: Delta::default(), }; - // 1 (protocol version) + 1 (message tag) + 2 (digest len) + 1 (delta end op) - test_serdeser_aux(&syn_ack, 5); + // 2 (magic number) + 1 (protocol version) + 1 (message tag) + 2 (digest len) + 1 (delta + // end op) + test_serdeser_aux(&syn_ack, 7); } { // 2 bytes. @@ -212,7 +221,7 @@ mod tests { let syn_ack = ChitchatMessage::SynAck { digest, delta }; // 1 byte (protocol version) + 1 byte (message tag) + 53 bytes (digest) + 60 bytes // (delta). - test_serdeser_aux(&syn_ack, 1 + 1 + 53 + 60); + test_serdeser_aux(&syn_ack, 2 + 1 + 1 + 53 + 60); } } @@ -221,7 +230,7 @@ mod tests { { let delta = Delta::default(); let ack = ChitchatMessage::Ack { delta }; - test_serdeser_aux(&ack, 3); + test_serdeser_aux(&ack, 5); } { // 4 bytes. @@ -233,12 +242,12 @@ mod tests { delta.add_kv(&node, "key", "value", 0, true); delta.set_serialized_len(60); let ack = ChitchatMessage::Ack { delta }; - test_serdeser_aux(&ack, 1 + 1 + 60); + test_serdeser_aux(&ack, 2 + 1 + 1 + 60); } } #[test] fn test_bad_cluster() { - test_serdeser_aux(&ChitchatMessage::BadCluster, 2); + test_serdeser_aux(&ChitchatMessage::BadCluster, 4); } }