Skip to content

Commit

Permalink
attribute macro: both deserialization variants
Browse files Browse the repository at this point in the history
  • Loading branch information
kkohbrok committed Nov 14, 2023
1 parent d0a38a9 commit 0cd330b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 22 deletions.
33 changes: 14 additions & 19 deletions tls_codec/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,13 @@
//! A pattern like this can be created via the `conditionally_deserializable`
//! attribute macro (requires the `conditional_deserialization` feature flag).
//!
//! The macro takes a single argument, which is the name of the deserialize
//! trait variant (either `Reader` or `Bytes`) that should be derived for the
//! deserializable struct part.
//!
//! The macro will then add a boolean const generic to the struct and create two
//! aliases, one for the deserializable variant and one for the undeserializable
//! one.
//! The macro adds a boolean const generic to the struct and creates two
//! aliases, one for the deserializable variant (with a "`Deserializable`"
//! prefix) and one for the undeserializable one (with an "`Undeserializable`"
//! prefix).
//!
//! ```
//! #[conditionally_deserializable(Reader)]
//! #[conditionally_deserializable]
//! #[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
//! struct ExampleStruct {
//! a: u8,
Expand Down Expand Up @@ -1203,30 +1200,26 @@ impl UndeserializableExampleStruct {
#[cfg(feature = "conditional_deserialization")]
#[proc_macro_attribute]
pub fn conditionally_deserializable(
input: TokenStream,
_input: TokenStream,
annotated_item: TokenStream,
) -> TokenStream {
let annotated_item = parse_macro_input!(annotated_item as ItemStruct);
let input = parse_macro_input!(input as Ident);
impl_conditionally_deserializable(annotated_item, input).into()
impl_conditionally_deserializable(annotated_item).into()
}

#[cfg(feature = "conditional_deserialization")]
fn impl_conditionally_deserializable(mut annotated_item: ItemStruct, input: Ident) -> TokenStream2 {
fn impl_conditionally_deserializable(mut annotated_item: ItemStruct) -> TokenStream2 {
let deserializable_const_generic: ConstParam = parse_quote! {const IS_DESERIALIZABLE: bool};
// Add the DESERIALIZABLE const generic to the struct
annotated_item
.generics
.params
.push(deserializable_const_generic.into());
// Derive either TlsDeserialize or TlsDeserializeBytes depending on the input
let deserialize_implementation = if input == Ident::new("Bytes", Span::call_site()) {
impl_deserialize_bytes(parse_ast(annotated_item.clone().into()).unwrap())
} else if input == Ident::new("Reader", Span::call_site()) {
impl_deserialize(parse_ast(annotated_item.clone().into()).unwrap())
} else {
panic!("verifiable attribute macro only supports \"Bytes\" and \"Reader\" as input");
};
let deserialize_bytes_implementation =
impl_deserialize_bytes(parse_ast(annotated_item.clone().into()).unwrap());
let deserialize_implementation =
impl_deserialize(parse_ast(annotated_item.clone().into()).unwrap());
let (impl_generics, ty_generics, _) = annotated_item.generics.split_for_impl();
let (_deserializable_impl_generics, deserializable_ty_generics) =
restrict_conditional_generic(impl_generics.clone(), ty_generics.clone(), true);
Expand All @@ -1250,5 +1243,7 @@ fn impl_conditionally_deserializable(mut annotated_item: ItemStruct, input: Iden
#annotated_item_visibility type #deserializable_ident = #annotated_item_ident #deserializable_ty_generics;

#deserialize_implementation

#deserialize_bytes_implementation
}
}
6 changes: 3 additions & 3 deletions tls_codec/derive/tests/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ mod conditional_deserialization {

#[test]
fn conditionally_deserializable_struct() {
#[conditionally_deserializable(Reader)]
#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct ExampleStruct {
a: u8,
Expand All @@ -549,7 +549,7 @@ mod conditional_deserialization {
assert_eq!(deserializable_struct.a, undeserializable_struct.a);
assert_eq!(deserializable_struct.b, undeserializable_struct.b);

#[conditionally_deserializable(Reader)]
#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct SecondExampleStruct {
a: u8,
Expand All @@ -559,7 +559,7 @@ mod conditional_deserialization {

#[test]
fn conditional_deserializable_struct_bytes() {
#[conditionally_deserializable(Bytes)]
#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct ExampleStruct {
a: u8,
Expand Down

0 comments on commit 0cd330b

Please sign in to comment.