diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index 992f36197e..c7fb397d30 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -1,7 +1,9 @@ #[cfg(not(feature = "std"))] use alloc::vec::Vec; +use core::fmt; use anyhow::ensure; +use serde::de::{self, Visitor}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::field::goldilocks_field::GoldilocksField; @@ -193,19 +195,48 @@ impl GenericHashOut for BytesHash { } impl Serialize for BytesHash { - fn serialize(&self, _serializer: S) -> Result + fn serialize(&self, serializer: S) -> Result where S: Serializer, { - todo!() + serializer.serialize_bytes(&self.0) + } +} + +struct ByteHashVisitor; + +impl<'de, const N: usize> Visitor<'de> for ByteHashVisitor { + type Value = BytesHash; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "an array containing exactly {} bytes", N) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let mut bytes = [0u8; N]; + for i in 0..N { + bytes[i] = seq.next_element().unwrap().unwrap(); + } + Ok(BytesHash(bytes)) + } + + fn visit_bytes(self, s: &[u8]) -> Result + where + E: de::Error, + { + let bytes = s.try_into().unwrap(); + Ok(BytesHash(bytes)) } } impl<'de, const N: usize> Deserialize<'de> for BytesHash { - fn deserialize(_deserializer: D) -> Result + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - todo!() + deserializer.deserialize_seq(ByteHashVisitor::) } }