From 7a1af34c54d33f346d7138f65cf45126a8ef7d55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=AD=90=EF=B8=8FNINIKA=E2=AD=90=EF=B8=8F?= Date: Mon, 7 Oct 2024 17:58:51 +0300 Subject: [PATCH] refactor(schema_derive): use a smarter algorithm for auto-generated trait bounds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ⭐️NINIKA⭐️ --- Cargo.toml | 6 +- crates/iroha_data_model/src/isi.rs | 11 - crates/iroha_schema_derive/Cargo.toml | 2 +- crates/iroha_schema_derive/src/lib.rs | 44 ++- .../iroha_schema_derive/src/trait_bounds.rs | 262 ++++++++++++++++++ .../tests/ui_pass/derive_into_schema.rs | 1 - 6 files changed, 300 insertions(+), 26 deletions(-) create mode 100644 crates/iroha_schema_derive/src/trait_bounds.rs diff --git a/Cargo.toml b/Cargo.toml index f1e82af249e..fdca347237c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -128,9 +128,9 @@ mv = { version = "0.1.0" } [workspace.lints] rustdoc.private_doc_tests = "deny" -rust.future_incompatible = {level = "deny", priority = -1 } -rust.nonstandard_style = {level = "deny", priority = -1 } -rust.rust_2018_idioms = {level = "deny", priority = -1 } +rust.future_incompatible = { level = "deny", priority = -1 } +rust.nonstandard_style = { level = "deny", priority = -1 } +rust.rust_2018_idioms = { level = "deny", priority = -1 } rust.unused = { level = "deny", priority = -1 } rust.anonymous_parameters = "deny" diff --git a/crates/iroha_data_model/src/isi.rs b/crates/iroha_data_model/src/isi.rs index 1d6a7e16f0a..53688a21bd4 100644 --- a/crates/iroha_data_model/src/isi.rs +++ b/crates/iroha_data_model/src/isi.rs @@ -264,7 +264,6 @@ mod transparent { isi! { /// Generic instruction to set key value at the object. - #[schema(bounds = "O: Identifiable, O::Id: IntoSchema")] pub struct SetKeyValue { /// Where to set key value. pub object: O::Id, @@ -356,7 +355,6 @@ mod transparent { isi! { /// Generic instruction to remove key value at the object. 
- #[schema(bounds = "O: Identifiable, O::Id: IntoSchema")] pub struct RemoveKeyValue { /// From where to remove key value. pub object: O::Id, @@ -437,7 +435,6 @@ mod transparent { isi! { /// Generic instruction for a registration of an object to the identifiable destination. - #[schema(bounds = "O: Registered, O::With: IntoSchema")] #[serde(transparent)] pub struct Register { /// The object that should be registered, should be uniquely identifiable by its id. @@ -524,7 +521,6 @@ mod transparent { isi! { /// Generic instruction for an unregistration of an object from the identifiable destination. - #[schema(bounds = "O: Identifiable, O::Id: IntoSchema")] pub struct Unregister { /// [`Identifiable::Id`] of the object which should be unregistered. pub object: O::Id, @@ -606,7 +602,6 @@ mod transparent { isi! { /// Generic instruction for a mint of an object to the identifiable destination. - #[schema(bounds = "O: IntoSchema, D: Identifiable, D::Id: IntoSchema")] pub struct Mint { /// Object which should be minted. pub object: O, @@ -656,7 +651,6 @@ mod transparent { isi! { /// Generic instruction for a burn of an object to the identifiable destination. - #[schema(bounds = "O: IntoSchema, D: Identifiable, D::Id: IntoSchema")] pub struct Burn { /// Object which should be burned. pub object: O, @@ -706,9 +700,6 @@ mod transparent { isi! { /// Generic instruction for a transfer of an object from the identifiable source to the identifiable destination. - #[schema(bounds = "S: Identifiable, S::Id: IntoSchema, \ - O: IntoSchema, \ - D: Identifiable, D::Id: IntoSchema")] pub struct Transfer { /// Source object `Id`. pub source: S::Id, @@ -802,7 +793,6 @@ mod transparent { isi! { /// Generic instruction for granting permission to an entity. - #[schema(bounds = "O: IntoSchema, D: Identifiable, D::Id: IntoSchema")] pub struct Grant { /// Object to grant. pub object: O, @@ -863,7 +853,6 @@ mod transparent { isi! { /// Generic instruction for revoking permission from an entity. 
- #[schema(bounds = "O: IntoSchema, D: Identifiable, D::Id: IntoSchema")] pub struct Revoke { /// Object to revoke. pub object: O, diff --git a/crates/iroha_schema_derive/Cargo.toml b/crates/iroha_schema_derive/Cargo.toml index 5a0b667d973..b18394d033d 100644 --- a/crates/iroha_schema_derive/Cargo.toml +++ b/crates/iroha_schema_derive/Cargo.toml @@ -16,7 +16,7 @@ proc-macro = true [dependencies] iroha_macro_utils = { path = "../iroha_macro_utils" } -syn = { workspace = true, features = ["default", "full"] } +syn = { workspace = true, features = ["default", "full", "visit"] } proc-macro2 = { workspace = true } quote = { workspace = true } manyhow = { workspace = true, features = ["darling"] } diff --git a/crates/iroha_schema_derive/src/lib.rs b/crates/iroha_schema_derive/src/lib.rs index b1c8b40753f..e312821ea63 100644 --- a/crates/iroha_schema_derive/src/lib.rs +++ b/crates/iroha_schema_derive/src/lib.rs @@ -3,6 +3,8 @@ // darling-generated code triggers this lint #![allow(clippy::option_if_let_else)] +mod trait_bounds; + use darling::{ast::Style, FromAttributes, FromDeriveInput, FromField, FromMeta, FromVariant}; use iroha_macro_utils::Emitter; use manyhow::{emit, error_message, manyhow, Result}; @@ -10,6 +12,19 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; use syn::parse_quote; +fn add_bounds_to_all_generic_parameters(generics: &mut syn::Generics, bound: syn::Path) { + let generic_type_parameters = generics + .type_params() + .map(|ty_param| ty_param.ident.clone()) + .collect::>(); + if !generic_type_parameters.is_empty() { + let where_clause = generics.make_where_clause(); + for ty in generic_type_parameters { + where_clause.predicates.push(parse_quote!(#ty: #bound)); + } + } +} + fn override_where_clause( emitter: &mut Emitter, where_clause: Option<&syn::WhereClause>, @@ -33,11 +48,9 @@ pub fn type_id_derive(input: TokenStream) -> Result { fn impl_type_id(input: &mut syn::DeriveInput) -> TokenStream { let name = &input.ident; - 
input.generics.type_params_mut().for_each(|ty_param| { - ty_param - .bounds - .push(syn::parse_quote! {iroha_schema::TypeId}); - }); + // Unlike IntoSchema, `TypeId` bounds are required only on the generic type parameters, as in the standard "dumb" algorithm + // The schema of the fields are irrelevant here, as we only need the names of the parameters + add_bounds_to_all_generic_parameters(&mut input.generics, parse_quote!(iroha_schema::TypeId)); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let type_id_body = trait_body(name, &input.generics, true); @@ -201,11 +214,22 @@ pub fn schema_derive(input: TokenStream) -> TokenStream { return emitter.finish_token_stream(); }; - input.generics.type_params_mut().for_each(|ty_param| { - ty_param - .bounds - .push(parse_quote! {iroha_schema::IntoSchema}); - }); + // first of all, `IntoSchema` impls are required for all generic type parameters to be able to call `type_name` on them + add_bounds_to_all_generic_parameters( + &mut input.generics, + parse_quote!(iroha_schema::IntoSchema), + ); + + // add trait bounds on field types using the same algorithm that parity scale codec uses + emitter.handle(trait_bounds::add( + &input.ident, + &mut input.generics, + &input.data, + syn::parse_quote!(iroha_schema::IntoSchema), + None, + false, + &syn::parse_quote!(iroha_schema), + )); let impl_type_id = impl_type_id(&mut syn::parse2(original_input).unwrap()); diff --git a/crates/iroha_schema_derive/src/trait_bounds.rs b/crates/iroha_schema_derive/src/trait_bounds.rs new file mode 100644 index 00000000000..fe9548568a1 --- /dev/null +++ b/crates/iroha_schema_derive/src/trait_bounds.rs @@ -0,0 +1,262 @@ +//! Algorithm for generating trait bounds in IntoSchema derive +//! +//! 
Based on https://github.com/paritytech/parity-scale-codec/blob/2c61d4ab70dfa157556430546441cd2deb5031f2/derive/src/trait_bounds.rs + +use std::iter; + +use proc_macro2::Ident; +use syn::{ + parse_quote, + visit::{self, Visit}, + Generics, Result, Type, TypePath, +}; + +use crate::{IntoSchemaData, IntoSchemaField}; + +/// Visits the ast and checks if one of the given idents is found. +struct ContainIdents<'a> { + result: bool, + idents: &'a [Ident], +} + +impl<'ast> Visit<'ast> for ContainIdents<'_> { + fn visit_ident(&mut self, i: &'ast Ident) { + if self.idents.iter().any(|id| id == i) { + self.result = true; + } + } +} + +/// Checks if the given type contains one of the given idents. +fn type_contain_idents(ty: &Type, idents: &[Ident]) -> bool { + let mut visitor = ContainIdents { + result: false, + idents, + }; + visitor.visit_type(ty); + visitor.result +} + +/// Visits the ast and checks if a type path starts with the given ident. +struct TypePathStartsWithIdent<'a> { + result: bool, + ident: &'a Ident, +} + +impl<'ast> Visit<'ast> for TypePathStartsWithIdent<'_> { + fn visit_type_path(&mut self, i: &'ast TypePath) { + if let Some(segment) = i.path.segments.first() { + if &segment.ident == self.ident { + self.result = true; + return; + } + } + + visit::visit_type_path(self, i); + } +} + +/// Checks if the given type path or any containing type path starts with the given ident. +fn type_path_or_sub_starts_with_ident(ty: &TypePath, ident: &Ident) -> bool { + let mut visitor = TypePathStartsWithIdent { + result: false, + ident, + }; + visitor.visit_type_path(ty); + visitor.result +} + +/// Checks if the given type or any containing type path starts with the given ident. 
+fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool { + let mut visitor = TypePathStartsWithIdent { + result: false, + ident, + }; + visitor.visit_type(ty); + visitor.result +} + +/// Visits the ast and collects all type paths that do not start or contain the given ident. +/// +/// Returns `T`, `N`, `A` for `Vec<(Recursive<T, N>, A)>` with `Recursive` as ident. +struct FindTypePathsNotStartOrContainIdent<'a> { + result: Vec, + ident: &'a Ident, +} + +impl<'ast> Visit<'ast> for FindTypePathsNotStartOrContainIdent<'_> { + fn visit_type_path(&mut self, i: &'ast TypePath) { + if type_path_or_sub_starts_with_ident(i, self.ident) { + visit::visit_type_path(self, i); + } else { + self.result.push(i.clone()); + } + } +} + +/// Collects all type paths that do not start or contain the given ident in the given type. +/// +/// Returns `T`, `N`, `A` for `Vec<(Recursive<T, N>, A)>` with `Recursive` as ident. +fn find_type_paths_not_start_or_contain_ident(ty: &Type, ident: &Ident) -> Vec { + let mut visitor = FindTypePathsNotStartOrContainIdent { + result: Vec::new(), + ident, + }; + visitor.visit_type(ty); + visitor.result +} + +#[allow(clippy::too_many_arguments)] +/// Add required trait bounds to all generic types. +/// +/// This adds types of all the fields of the struct or enum that use a generic parameter to the where clause, with the following exceptions: +/// +/// - If the field is marked as `#[codec(skip)]`, a different bound or no bound at all is added (based on the value of `codec_skip_bound` parameter). +/// - If the field mentions the input type itself, no bound is added. Heuristics are used, so this might not work in all cases. +/// - If the field is marked as `#[codec(compact)]`, the type `Compact<T>` is used instead of `T`. 
+pub fn add( + input_ident: &Ident, + generics: &mut Generics, + data: &IntoSchemaData, + // custom_trait_bound: Option>, + codec_bound: syn::Path, + codec_skip_bound: Option, + dumb_trait_bounds: bool, + crate_path: &syn::Path, +) -> Result<()> { + let skip_type_params = Vec::::new(); + // NOTE: not implementing custom trait bounds for now + // can be implemented later if needed + // = match custom_trait_bound { + // Some(CustomTraitBound::SpecifiedBounds { bounds, .. }) => { + // generics.make_where_clause().predicates.extend(bounds); + // return Ok(()); + // } + // Some(CustomTraitBound::SkipTypeParams { type_names, .. }) => { + // type_names.into_iter().collect::>() + // } + // None => Vec::new(), + // }; + + let ty_params = generics + .type_params() + .filter(|tp| skip_type_params.iter().all(|skip| skip != &tp.ident)) + .map(|tp| tp.ident.clone()) + .collect::>(); + if ty_params.is_empty() { + return Ok(()); + } + + let codec_types = + get_types_to_add_trait_bound(input_ident, data, &ty_params, dumb_trait_bounds)?; + + let compact_types = collect_types(data, |t| t.codec_attrs.compact)? + .into_iter() + // Only add a bound if the type uses a generic + .filter(|ty| type_contain_idents(ty, &ty_params)) + .collect::>(); + + let skip_types = if codec_skip_bound.is_some() { + let needs_default_bound = |f: &IntoSchemaField| f.codec_attrs.skip; + collect_types(data, needs_default_bound)? 
+ .into_iter() + // Only add a bound if the type uses a generic + .filter(|ty| type_contain_idents(ty, &ty_params)) + .collect::>() + } else { + Vec::new() + }; + + if !codec_types.is_empty() || !compact_types.is_empty() || !skip_types.is_empty() { + let where_clause = generics.make_where_clause(); + + codec_types.into_iter().for_each(|ty| { + where_clause + .predicates + .push(parse_quote!(#ty : #codec_bound)) + }); + + compact_types.into_iter().for_each(|ty| { + where_clause + .predicates + .push(parse_quote!(#crate_path::Compact<#ty> : #codec_bound)) + }); + + skip_types.into_iter().for_each(|ty| { + let codec_skip_bound = codec_skip_bound.as_ref(); + where_clause + .predicates + .push(parse_quote!(#ty : #codec_skip_bound)) + }); + } + + Ok(()) +} + +/// Returns all types that must be added to the where clause with the respective trait bound. +fn get_types_to_add_trait_bound( + input_ident: &Ident, + data: &IntoSchemaData, + ty_params: &[Ident], + dumb_trait_bound: bool, +) -> Result> { + if dumb_trait_bound { + Ok(ty_params.iter().map(|t| parse_quote!( #t )).collect()) + } else { + let needs_codec_bound = |f: &IntoSchemaField| { + !f.codec_attrs.compact + && true // utils::get_encoded_as_type(f).is_none() + && !f.codec_attrs.skip + }; + let res = collect_types(data, needs_codec_bound)? + .into_iter() + // Only add a bound if the type uses a generic + .filter(|ty| type_contain_idents(ty, ty_params)) + // If a struct contains itself as field type, we can not add this type into the where + // clause. This is required to work around the following compiler bug: https://github.com/rust-lang/rust/issues/47032 + .flat_map(|ty| { + find_type_paths_not_start_or_contain_ident(&ty, input_ident) + .into_iter() + .map(Type::Path) + // Remove again types that do not contain any of our generic parameters + .filter(|ty| type_contain_idents(ty, ty_params)) + // Add back the original type, as we don't want to lose it. 
+ .chain(iter::once(ty)) + }) + // Remove all remaining types that start/contain the input ident to not have them in the + // where clause. + .filter(|ty| !type_or_sub_type_path_starts_with_ident(ty, input_ident)) + .collect(); + + Ok(res) + } +} + +fn collect_types( + data: &IntoSchemaData, + type_filter: fn(&IntoSchemaField) -> bool, +) -> Result> { + let types = match *data { + IntoSchemaData::Struct(ref data) => data + .fields + .iter() + .filter(|f| type_filter(f)) + .map(|f| f.ty.clone()) + .collect(), + + IntoSchemaData::Enum(ref variants) => variants + .iter() + .filter(|variant| !variant.codec_attrs.skip) + .flat_map(|variant| { + variant + .fields + .iter() + .filter(|f| type_filter(f)) + .map(|f| f.ty.clone()) + .collect::>() + }) + .collect(), + }; + + Ok(types) +} diff --git a/crates/iroha_schema_derive/tests/ui_pass/derive_into_schema.rs b/crates/iroha_schema_derive/tests/ui_pass/derive_into_schema.rs index f75295a3f82..fe8455a223f 100644 --- a/crates/iroha_schema_derive/tests/ui_pass/derive_into_schema.rs +++ b/crates/iroha_schema_derive/tests/ui_pass/derive_into_schema.rs @@ -47,7 +47,6 @@ pub trait Trait { } #[derive(IntoSchema)] -#[schema(bounds = "T: Trait, T::Assoc: IntoSchema")] pub struct WithComplexGeneric { _value: T::Assoc, }