Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls_codec: feature for conditional deserialization derivation #1214

Merged
merged 18 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions tls_codec/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@ criterion = { version = "0.5", default-features = false }
regex = "1.8"

[features]
default = [ "std" ]
arbitrary = [ "std", "dep:arbitrary" ]
derive = [ "tls_codec_derive" ]
serde = [ "std", "dep:serde" ]
default = ["std"]
arbitrary = ["std", "dep:arbitrary"]
derive = ["tls_codec_derive"]
serde = ["std", "dep:serde"]
mls = [] # In MLS variable length vectors are limited compared to QUIC.
std = [ "tls_codec_derive?/std" ]
std = ["tls_codec_derive?/std"]
conditional_deserialization = [
"derive",
"tls_codec_derive/conditional_deserialization",
]

[[bench]]
name = "tls_vec"
Expand Down
5 changes: 3 additions & 2 deletions tls_codec/derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ rust-version = "1.60"
proc-macro = true

[dependencies]
syn = { version = "2", features = ["parsing"] }
syn = { version = "2", features = ["parsing", "full"] }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we always need full? If not, let's only enable it for the conditional_deserialization feature.

quote = "1.0"
proc-macro2 = "1.0"

Expand All @@ -23,5 +23,6 @@ tls_codec = { path = "../" }
trybuild = "1"

[features]
default = [ "std" ]
default = ["std"]
conditional_deserialization = []
std = []
145 changes: 145 additions & 0 deletions tls_codec/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,36 @@
//! c: u8,
//! }
//! ```
//!
//! ## Conditional deserialization via the `conditionally_deserializable` attribute macro
//!
//! In some cases, it can be useful to have two variants of a struct, where one
//! is deserializable and one isn't. For example, the deserializable variant of
//! the struct could represent an unverified message, where only verification
//! produces the verified variant. Further processing could then be restricted
//! to the undeserializable struct variant.
//!
//! A pattern like this can be created via the `conditionally_deserializable`
//! attribute macro (requires the `conditional_deserialization` feature flag).
//!
//! The macro adds a boolean const generic to the struct and creates two
//! aliases, one for the deserializable variant (with a "`Deserializable`"
//! prefix) and one for the undeserializable one (with an "`Undeserializable`"
//! prefix).
//!
//! ```
//! #[conditionally_deserializable]
//! #[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
//! struct ExampleStruct {
//! a: u8,
//! b: u16,
//! }
//!
//! let undeserializable_struct = UndeserializableExampleStruct { a: 1, b: 2 };
//! let serialized = undeserializable_struct.tls_serialize_detached().unwrap();
//! let deserializable_struct =
//! DeserializableExampleStruct::tls_deserialize(&mut serialized.as_slice()).unwrap();
//! ```

extern crate proc_macro;
extern crate proc_macro2;
Expand All @@ -176,6 +206,9 @@ use syn::{
Expr, ExprLit, ExprPath, Field, Generics, Ident, Lit, Member, Meta, Result, Token, Type,
};

#[cfg(feature = "conditional_deserialization")]
use syn::{parse_quote, ConstParam, ImplGenerics, ItemStruct, TypeGenerics};

/// Attribute name to identify attributes to be processed by derive-macros in this crate.
const ATTR_IDENT: &str = "tls_codec";

Expand Down Expand Up @@ -895,6 +928,27 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr
}
}

#[cfg(feature = "conditional_deserialization")]
fn restrict_conditional_generic(
impl_generics: ImplGenerics,
ty_generics: TypeGenerics,
deserializable: bool,
) -> (TokenStream2, TokenStream2) {
let impl_generics = quote! { #impl_generics }
.to_string()
.replace(" const IS_DESERIALIZABLE : bool ", "")
.replace("<>", "")
.parse()
.unwrap();
let deserializable_string = if deserializable { "true" } else { "false" };
let ty_generics = quote! { #ty_generics }
.to_string()
.replace("IS_DESERIALIZABLE", deserializable_string)
.parse()
.unwrap();
(impl_generics, ty_generics)
}

#[allow(unused_variables)]
fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 {
match parsed_ast {
Expand All @@ -914,6 +968,9 @@ fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 {
.map(|p| p.for_trait("Deserialize"))
.collect::<Vec<_>>();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
#[cfg(feature = "conditional_deserialization")]
let (impl_generics, ty_generics) =
restrict_conditional_generic(impl_generics, ty_generics, true);
quote! {
impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause {
#[cfg(feature = "std")]
Expand Down Expand Up @@ -1003,6 +1060,9 @@ fn impl_deserialize_bytes(parsed_ast: TlsStruct) -> TokenStream2 {
.map(|p| p.for_trait("DeserializeBytes"))
.collect::<Vec<_>>();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
#[cfg(feature = "conditional_deserialization")]
let (impl_generics, ty_generics) =
restrict_conditional_generic(impl_generics, ty_generics, true);
quote! {
impl #impl_generics tls_codec::DeserializeBytes for #ident #ty_generics #where_clause {
fn tls_deserialize_bytes(bytes: &[u8]) -> core::result::Result<(Self, &[u8]), tls_codec::Error> {
Expand Down Expand Up @@ -1102,3 +1162,88 @@ fn partition_skipped(
(members_skip, member_prefixes_skip),
)
}

/// The `conditionally_deserializable` attribute macro takes as input either
/// `Bytes` or `Reader` and does the following:
/// * Add a boolean const generic to the struct indicating if the variant of the
/// struct is deserializable or not.
/// * Depending on the input derive either the `TlsDeserialize` or
/// `TlsDeserializeBytes` trait for the deserializable variant
/// * Create type aliases for the deserializable and undeserializable variant of
/// the struct, where the alias is the name of the struct prefixed with
/// `Deserializable` or `Undeserializable` respectively.
///
/// The `conditionally_deserializable` attribute macro is only available if the
/// `conditional_deserialization` feature is enabled.
///
#[cfg_attr(
feature = "conditional_deserialization",
doc = r##"
```compile_fail
use tls_codec_derive::{TlsSerialize, TlsDeserialize, TlsSize, conditionally_deserializable};

#[conditionally_deserializable(Bytes)]
#[derive(TlsDeserialize, TlsSerialize, TlsSize)]
struct ExampleStruct {
pub a: u16,
}

impl UndeserializableExampleStruct {
#[cfg(feature = "conditional_deserialization")]
fn deserialize(bytes: &[u8]) -> Result<Self, tls_codec::Error> {
Self::tls_deserialize_exact(bytes)
}
}
```
"##
)]
#[cfg(feature = "conditional_deserialization")]
#[proc_macro_attribute]
pub fn conditionally_deserializable(
_input: TokenStream,
annotated_item: TokenStream,
) -> TokenStream {
let annotated_item = parse_macro_input!(annotated_item as ItemStruct);
impl_conditionally_deserializable(annotated_item).into()
}

#[cfg(feature = "conditional_deserialization")]
fn impl_conditionally_deserializable(mut annotated_item: ItemStruct) -> TokenStream2 {
let deserializable_const_generic: ConstParam = parse_quote! {const IS_DESERIALIZABLE: bool};
// Add the DESERIALIZABLE const generic to the struct
annotated_item
.generics
.params
.push(deserializable_const_generic.into());
// Derive either TlsDeserialize or TlsDeserializeBytes depending on the input
let deserialize_bytes_implementation =
impl_deserialize_bytes(parse_ast(annotated_item.clone().into()).unwrap());
let deserialize_implementation =
impl_deserialize(parse_ast(annotated_item.clone().into()).unwrap());
let (impl_generics, ty_generics, _) = annotated_item.generics.split_for_impl();
let (_deserializable_impl_generics, deserializable_ty_generics) =
restrict_conditional_generic(impl_generics.clone(), ty_generics.clone(), true);
let (_undeserializable_impl_generics, undeserializable_ty_generics) =
restrict_conditional_generic(impl_generics.clone(), ty_generics.clone(), false);
let annotated_item_ident = annotated_item.ident.clone();
let deserializable_ident = Ident::new(
&format!("Deserializable{}", annotated_item_ident),
Span::call_site(),
);
let undeserializable_ident = Ident::new(
&format!("Undeserializable{}", annotated_item_ident),
Span::call_site(),
);
let annotated_item_visibility = annotated_item.vis.clone();
// For now, we assume that the struct doesn't
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there's something missing here?

quote! {
#annotated_item

#annotated_item_visibility type #undeserializable_ident = #annotated_item_ident #undeserializable_ty_generics;
#annotated_item_visibility type #deserializable_ident = #annotated_item_ident #deserializable_ty_generics;

#deserialize_implementation

#deserialize_bytes_implementation
}
}
46 changes: 46 additions & 0 deletions tls_codec/derive/tests/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,49 @@ fn type_with_unknowns() {
let deserialized = TypeWithUnknowns::tls_deserialize_exact(incoming);
assert!(matches!(deserialized, Err(Error::UnknownValue(3))));
}

#[cfg(feature = "conditional_deserialization")]
mod conditional_deserialization {
use tls_codec::{Deserialize, Serialize};
use tls_codec_derive::{conditionally_deserializable, TlsSerialize, TlsSize};

#[test]
fn conditionally_deserializable_struct() {
#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct ExampleStruct {
a: u8,
b: u16,
}

let undeserializable_struct = UndeserializableExampleStruct { a: 1, b: 2 };
let serialized = undeserializable_struct.tls_serialize_detached().unwrap();
let deserializable_struct =
DeserializableExampleStruct::tls_deserialize(&mut serialized.as_slice()).unwrap();
assert_eq!(deserializable_struct.a, undeserializable_struct.a);
assert_eq!(deserializable_struct.b, undeserializable_struct.b);

#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct SecondExampleStruct {
a: u8,
b: u16,
}
}

#[test]
fn conditional_deserializable_struct_bytes() {
#[conditionally_deserializable]
#[derive(TlsSize, TlsSerialize, PartialEq, Debug)]
struct ExampleStruct {
a: u8,
b: u16,
}
let undeserializable_struct = UndeserializableExampleStruct { a: 1, b: 2 };
let serialized = undeserializable_struct.tls_serialize_detached().unwrap();
let deserializable_struct =
DeserializableExampleStruct::tls_deserialize_exact(&mut &*serialized).unwrap();
assert_eq!(deserializable_struct.a, undeserializable_struct.a);
assert_eq!(deserializable_struct.b, undeserializable_struct.b);
}
}