diff --git a/data_model/derive/Cargo.toml b/data_model/derive/Cargo.toml index aec550c0668..bea78f7492f 100644 --- a/data_model/derive/Cargo.toml +++ b/data_model/derive/Cargo.toml @@ -11,7 +11,7 @@ license.workspace = true proc-macro = true [dependencies] -syn2 = { workspace = true, features = ["default", "full", "extra-traits"] } +syn2 = { workspace = true, features = ["default", "full", "extra-traits", "visit-mut"] } quote = { workspace = true } darling = { workspace = true } proc-macro2 = { workspace = true } diff --git a/data_model/derive/src/partially_tagged.rs b/data_model/derive/src/partially_tagged/mod.rs similarity index 92% rename from data_model/derive/src/partially_tagged.rs rename to data_model/derive/src/partially_tagged/mod.rs index f446f3e1b91..a17fd95aeb0 100644 --- a/data_model/derive/src/partially_tagged.rs +++ b/data_model/derive/src/partially_tagged/mod.rs @@ -2,6 +2,8 @@ // darling-generated code triggers this lint #![allow(clippy::option_if_let_else)] +mod resolve_self; + use darling::{FromDeriveInput, FromVariant}; use manyhow::Result; use proc_macro2::TokenStream; @@ -40,25 +42,36 @@ impl PartiallyTaggedEnum { fn untagged_variants(&self) -> impl Iterator { self.variants().filter(|variant| variant.untagged) } + + /// Returns a type that corresponds to `Self`, handling the generics as necessary + fn self_ty(&self) -> syn2::Type { + let ident = &self.ident; + let (_, type_generics, _) = self.generics.split_for_impl(); + + parse_quote!(#ident #type_generics) + } } impl PartiallyTaggedVariant { - fn ty(&self) -> &syn2::Type { - self.fields.fields.first().expect( + fn ty(&self, self_ty: &syn2::Type) -> syn2::Type { + let ty = self.fields.fields.first().expect( "BUG: Only newtype enums are supported. Enforced by `darling(supports(enum_newtype))`", - ) + ).clone(); + + resolve_self::resolve_self(self_ty, ty) } } /// Convert from vector of variants to tuple of vectors consisting of variant's fields fn variants_to_tuple<'lt, I: Iterator>( + self_ty: &syn2::Type, variants: I, -) -> (Vec<&'lt Ident>, Vec<&'lt Type>, Vec<&'lt [Attribute]>) { +) -> (Vec<&'lt Ident>, Vec, Vec<&'lt [Attribute]>) { variants.fold( (Vec::new(), Vec::new(), Vec::new()), |(mut idents, mut types, mut attrs), variant| { idents.push(&variant.ident); - types.push(&variant.ty()); + types.push(variant.ty(self_ty)); attrs.push(&variant.attrs); (idents, types, attrs) }, @@ -72,9 +85,11 @@ pub fn impl_partially_tagged_serialize(input: &syn2::DeriveInput) -> Result Result { + self_ty: &'a syn2::Type, +} + +impl VisitMut for Visitor<'_> { + fn visit_type_mut(&mut self, ty: &mut syn2::Type) { + match ty { + syn2::Type::Path(path_ty) + if path_ty.qself.is_none() && path_ty.path.is_ident("Self") => + { + *ty = self.self_ty.clone(); + } + _ => syn2::visit_mut::visit_type_mut(self, ty), + } + } +} + +/// Transforms the [`resolving_ty`] by replacing `Self` with [`self_ty`]. +/// +/// This is required to be able to use `Self` in `PartiallyTaggedSerialize` and `PartiallyTaggedDeserialize`, +/// as they define an additional intermediate type during serialization/deserialization. Using `Self` there would refer to an incorrect type. +pub fn resolve_self(self_ty: &syn2::Type, mut resolving_ty: syn2::Type) -> syn2::Type { + Visitor { self_ty }.visit_type_mut(&mut resolving_ty); + resolving_ty +} + +#[cfg(test)] +mod tests { + use quote::ToTokens; + use syn2::{parse_quote, Type}; + + #[test] + fn test_resolve_self() { + let test_types = [ + parse_quote!(i32), + parse_quote!(Self), + parse_quote!(Vec), + parse_quote!((Self, Self)), + parse_quote!(::Type), + ]; + let expected_types = [ + parse_quote!(i32), + parse_quote!(()), + parse_quote!(Vec<()>), + parse_quote!(((), ())), + parse_quote!(<() as Trait>::Type), + ]; + let _: &Type = &test_types[0]; + let _: &Type = &expected_types[0]; + + for (test_type, expected_type) in test_types.iter().zip(expected_types.iter()) { + let resolved = super::resolve_self(&parse_quote!(()), test_type.clone()); + assert_eq!( + resolved, + *expected_type, + "Failed to resolve `Self` in `{}`", + test_type.to_token_stream().to_string() + ); + } + } +} diff --git a/data_model/derive/tests/partial_tagged_serde_self.rs b/data_model/derive/tests/partial_tagged_serde_self.rs new file mode 100644 index 00000000000..e4520e6ee03 --- /dev/null +++ b/data_model/derive/tests/partial_tagged_serde_self.rs @@ -0,0 +1,39 @@ +//! A test for `PartiallyTaggedSerialize` and `PartiallyTaggedDeserialize` which uses `Self` as a type + +use iroha_data_model_derive::{PartiallyTaggedDeserialize, PartiallyTaggedSerialize}; + +#[derive(Debug, PartialEq, Eq, PartiallyTaggedSerialize, PartiallyTaggedDeserialize)] +enum Expr { + Negate(Box), + #[serde_partially_tagged(untagged)] + Atom(T), +} + +#[test] +fn partially_tagged_serde() { + use Expr::*; + + let values = [ + Atom(42), + Negate(Box::new(Atom(42))), + Negate(Box::new(Negate(Box::new(Atom(42))))), + ]; + let serialized_values = [r#"42"#, r#"{"Negate":42}"#, r#"{"Negate":{"Negate":42}}"#]; + + for (value, serialized_value) in values.iter().zip(serialized_values.iter()) { + let serialized = serde_json::to_string(value) + .unwrap_or_else(|e| panic!("Failed to serialize `{:?}`: {:?}", value, e)); + assert_eq!( + serialized, *serialized_value, + "Serialized form of `{:?}` does not match the expected value", + value + ); + let deserialized: Expr = serde_json::from_str(serialized_value) + .unwrap_or_else(|e| panic!("Failed to deserialize `{:?}`: {:?}", serialized_value, e)); + assert_eq!( + *value, deserialized, + "Deserialized form of `{:?}` does not match the expected value", + value + ); + } +} diff --git a/data_model/src/predicate.rs b/data_model/src/predicate.rs index b619681f9a1..5581e8063f3 100644 --- a/data_model/src/predicate.rs +++ b/data_model/src/predicate.rs @@ -91,11 +91,11 @@ macro_rules! nontrivial { // references (e.g. &Value). pub enum GenericPredicateBox

{ /// Logically `&&` the results of applying the predicates. - And(NonTrivial>), + And(NonTrivial), /// Logically `||` the results of applying the predicats. - Or(NonTrivial>), + Or(NonTrivial), /// Negate the result of applying the predicate. - Not(Box>), + Not(Box), /// The raw predicate that must be applied. #[serde_partially_tagged(untagged)] Raw(P),