diff --git a/data_model/derive/src/id.rs b/data_model/derive/src/id.rs index ad57dfd789e..18af318dcf7 100644 --- a/data_model/derive/src/id.rs +++ b/data_model/derive/src/id.rs @@ -1,16 +1,109 @@ #![allow(clippy::str_to_string, clippy::mixed_read_write_in_expression)] -use manyhow::{bail, Result}; +use darling::{FromAttributes, FromDeriveInput, FromField}; +use iroha_macro_utils::Emitter; +use manyhow::emit; use proc_macro2::TokenStream; -use quote::quote; +use quote::{quote, ToTokens}; use syn2::parse_quote; -pub fn impl_id(input: &syn2::ItemStruct) -> Result { +mod kw { + syn2::custom_keyword!(transparent); +} + +enum IdAttr { + Missing, + Normal, + Transparent, +} + +impl FromAttributes for IdAttr { + fn from_attributes(attrs: &[syn2::Attribute]) -> darling::Result { + let mut accumulator = darling::error::Accumulator::default(); + let attrs = attrs + .iter() + .filter(|v| v.path().is_ident("id")) + .collect::>(); + let attr = match attrs.as_slice() { + [] => { + return accumulator.finish_with(IdAttr::Missing); + } + [attr] => attr, + [attr, ref tail @ ..] => { + accumulator.push( + darling::Error::custom("Only one `#[id]` attribute is allowed!").with_span( + &tail + .iter() + .map(syn2::spanned::Spanned::span) + .reduce(|a, b| a.join(b).unwrap()) + .unwrap(), + ), + ); + attr + } + }; + + let result = match &attr.meta { + syn2::Meta::Path(_) => IdAttr::Normal, + syn2::Meta::List(list) if list.parse_args::().is_ok() => { + IdAttr::Transparent + } + _ => { + accumulator.push( + darling::Error::custom("Expected `#[id]` or `#[id(transparent)]`") + .with_span(&attr), + ); + IdAttr::Normal + } + }; + + accumulator.finish_with(result) + } +} + +#[derive(FromDeriveInput)] +#[darling(supports(struct_any))] +struct IdDeriveInput { + ident: syn2::Ident, + generics: syn2::Generics, + data: darling::ast::Data, +} + +struct IdField { + ident: Option, + ty: syn2::Type, + id_attr: IdAttr, +} + +impl FromField for IdField { + fn from_field(field: &syn2::Field) -> darling::Result { + let ident = field.ident.clone(); + let ty = field.ty.clone(); + let id_attr = IdAttr::from_attributes(&field.attrs)?; + + Ok(Self { ident, ty, id_attr }) + } +} + +impl IdDeriveInput { + fn fields(&self) -> &darling::ast::Fields { + match &self.data { + darling::ast::Data::Struct(fields) => fields, + _ => unreachable!(), + } + } +} + +pub fn impl_id_eq_ord_hash(emitter: &mut Emitter, input: &syn2::DeriveInput) -> TokenStream { + let Some(input) = emitter.handle(IdDeriveInput::from_derive_input(input)) else { + return quote!(); + }; + let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let identifiable_derive = derive_identifiable(input)?; + let identifiable_derive = derive_identifiable(emitter, &input); - Ok(quote! { + quote! { #identifiable_derive impl #impl_generics ::core::cmp::PartialOrd for #name #ty_generics #where_clause where Self: Identifiable { @@ -38,15 +131,15 @@ pub fn impl_id(input: &syn2::ItemStruct) -> Result { self.id().hash(state); } } - }) + } } -fn derive_identifiable(input: &syn2::ItemStruct) -> Result { +fn derive_identifiable(emitter: &mut Emitter, input: &IdDeriveInput) -> TokenStream { let name = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let (id_type, id_expr) = get_id_type(input)?; + let (id_type, id_expr) = get_id_type(emitter, input); - Ok(quote! { + quote! { impl #impl_generics Identifiable for #name #ty_generics #where_clause { type Id = #id_type; @@ -55,66 +148,42 @@ fn derive_identifiable(input: &syn2::ItemStruct) -> Result { #id_expr } } - }) + } } -fn get_id_type(input: &syn2::ItemStruct) -> Result<(TokenStream, TokenStream)> { - match &input.fields { - syn2::Fields::Named(fields) => { - for field in &fields.named { - let (field_name, field_ty) = (&field.ident, &field.ty); - - if is_identifier(&field.attrs) { - return Ok((quote! {#field_ty}, quote! {&self.#field_name})); - } - if is_transparent(&field.attrs) { - return Ok(( - quote! {<#field_ty as Identifiable>::Id}, - quote! {Identifiable::id(&self.#field_name)}, - )); - } +fn get_id_type(emitter: &mut Emitter, input: &IdDeriveInput) -> (syn2::Type, syn2::Expr) { + for (field_index, IdField { ty, ident, id_attr }) in input.fields().iter().enumerate() { + let field_name = ident.as_ref().map_or_else( + || syn2::Index::from(field_index).to_token_stream(), + ToTokens::to_token_stream, + ); + match id_attr { + IdAttr::Normal => { + return (ty.clone(), parse_quote! {&self.#field_name}); } - } - syn2::Fields::Unnamed(fields) => { - for (i, field) in fields.unnamed.iter().enumerate() { - let (field_id, field_ty): (syn2::Index, _) = (i.into(), &field.ty); - - if is_identifier(&field.attrs) { - return Ok((quote! {#field_ty}, quote! {&self.#field_id})); - } - if is_transparent(&field.attrs) { - return Ok(( - quote! {<#field_ty as Identifiable>::Id}, - quote! {Identifiable::id(&self.#field_id)}, - )); - } + IdAttr::Transparent => { + return ( + parse_quote! {<#ty as Identifiable>::Id}, + parse_quote! {Identifiable::id(&self.#field_name)}, + ); + } + IdAttr::Missing => { + // nothing here } } - syn2::Fields::Unit => {} } - match &input.fields { - syn2::Fields::Named(named) => { - for field in &named.named { - let field_ty = &field.ty; - - if field.ident.as_ref().expect("Field must be named") == "id" { - return Ok((quote! {#field_ty}, quote! {&self.id})); - } - } + for field in input.fields().iter() { + if field.ident.as_ref().is_some_and(|i| i == "id") { + return (field.ty.clone(), parse_quote! {&self.id}); } - syn2::Fields::Unnamed(_) | syn2::Fields::Unit => {} } - bail!(input, "Identifier not found") -} - -fn is_identifier(attrs: &[syn2::Attribute]) -> bool { - attrs.iter().any(|attr| attr == &parse_quote! {#[id]}) -} + emit!( + emitter, + "Could not find the identifier field. Either mark it with `#[id]` or have it named `id`" + ); -fn is_transparent(attrs: &[syn2::Attribute]) -> bool { - attrs - .iter() - .any(|attr| attr == &parse_quote! {#[id(transparent)]}) + // return dummy types + (parse_quote! {()}, parse_quote! {()}) } diff --git a/data_model/derive/src/lib.rs b/data_model/derive/src/lib.rs index 577404cdaa4..657a23633b5 100644 --- a/data_model/derive/src/lib.rs +++ b/data_model/derive/src/lib.rs @@ -230,10 +230,15 @@ pub fn model_single(input: TokenStream) -> TokenStream { /// #[manyhow] #[proc_macro_derive(IdEqOrdHash, attributes(id, opaque))] -pub fn id_eq_ord_hash(input: TokenStream) -> Result { - let input = syn2::parse2(input)?; +pub fn id_eq_ord_hash(input: TokenStream) -> TokenStream { + let mut emitter = Emitter::new(); - id::impl_id(&input) + let Some(input) = emitter.handle(syn2::parse2(input)) else { + return emitter.finish_token_stream(); + }; + + let result = id::impl_id_eq_ord_hash(&mut emitter, &input); + emitter.finish_token_stream_with(result) } /// [`Filter`] is used for code generation of `...Filter` structs and `...EventFilter` enums, as well as diff --git a/data_model/derive/tests/id_eq_ord_hash.rs b/data_model/derive/tests/id_eq_ord_hash.rs new file mode 100644 index 00000000000..91e94df415d --- /dev/null +++ b/data_model/derive/tests/id_eq_ord_hash.rs @@ -0,0 +1,117 @@ +//! Basic tests for traits derived by [`IdEqOrdHash`] macro + +use std::collections::BTreeSet; + +use iroha_data_model_derive::IdEqOrdHash; + +/// fake `Identifiable` trait +/// +/// Doesn't require `Into` implementation +pub trait Identifiable: Ord + Eq { + /// Type of the entity identifier + type Id: Ord + Eq + core::hash::Hash; + + /// Get reference to the type identifier + fn id(&self) -> &Self::Id; +} + +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] +struct ObjectId(char); + +#[derive(Debug, IdEqOrdHash)] +struct Object { + id: ObjectId, + #[allow(unused)] + data: i32, +} +#[derive(Debug, IdEqOrdHash)] +struct ObjectWithExplicitId { + #[id] + definitely_not_id: ObjectId, + #[allow(unused)] + data: i32, +} +#[derive(Debug, IdEqOrdHash)] +struct ObjectWithTransparentId { + #[id(transparent)] // delegate the id to `Object` type + definitely_not_id: Object, + #[allow(unused)] + data: i32, +} + +// some objects to play with in tests +const ID_A: ObjectId = ObjectId('A'); +const ID_B: ObjectId = ObjectId('B'); +const OBJECT_1A: Object = Object { id: ID_A, data: 1 }; +const OBJECT_1B: Object = Object { id: ID_B, data: 1 }; +const OBJECT_2A: Object = Object { id: ID_A, data: 2 }; +const EXPLICIT_OBJECT_1A: ObjectWithExplicitId = ObjectWithExplicitId { + definitely_not_id: ID_A, + data: 1, +}; +const EXPLICIT_OBJECT_1B: ObjectWithExplicitId = ObjectWithExplicitId { + definitely_not_id: ID_B, + data: 1, +}; +const EXPLICIT_OBJECT_2A: ObjectWithExplicitId = ObjectWithExplicitId { + definitely_not_id: ID_A, + data: 2, +}; +const TRANSPARENT_OBJECT_1A: ObjectWithTransparentId = ObjectWithTransparentId { + definitely_not_id: OBJECT_1A, + data: 1, +}; +const TRANSPARENT_OBJECT_1B: ObjectWithTransparentId = ObjectWithTransparentId { + definitely_not_id: OBJECT_1B, + data: 1, +}; +const TRANSPARENT_OBJECT_2A: ObjectWithTransparentId = ObjectWithTransparentId { + definitely_not_id: OBJECT_2A, + data: 2, +}; + +#[test] +fn id() { + assert_eq!(OBJECT_1A.id(), &ID_A); + assert_eq!(OBJECT_1B.id(), &ID_B); + assert_eq!(EXPLICIT_OBJECT_1A.id(), &ID_A); + assert_eq!(EXPLICIT_OBJECT_1B.id(), &ID_B); + assert_eq!(TRANSPARENT_OBJECT_1A.id(), &ID_A); + assert_eq!(TRANSPARENT_OBJECT_1B.id(), &ID_B); +} + +#[test] +fn id_eq() { + assert_eq!(OBJECT_1A, OBJECT_2A); + assert_ne!(OBJECT_1B, OBJECT_2A); + assert_eq!(EXPLICIT_OBJECT_1A, EXPLICIT_OBJECT_2A); + assert_ne!(EXPLICIT_OBJECT_1B, EXPLICIT_OBJECT_2A); + assert_eq!(TRANSPARENT_OBJECT_1A, TRANSPARENT_OBJECT_2A); + assert_ne!(TRANSPARENT_OBJECT_1B, TRANSPARENT_OBJECT_2A); +} + +#[test] +fn id_ord() { + assert!(OBJECT_1A < OBJECT_1B); + assert!(OBJECT_1B > OBJECT_1A); + assert!(EXPLICIT_OBJECT_1A < EXPLICIT_OBJECT_1B); + assert!(EXPLICIT_OBJECT_1B > EXPLICIT_OBJECT_1A); + assert!(TRANSPARENT_OBJECT_1A < TRANSPARENT_OBJECT_1B); + assert!(TRANSPARENT_OBJECT_1B > TRANSPARENT_OBJECT_1A); +} + +#[test] +fn id_hash() { + let mut set = BTreeSet::new(); + set.insert(OBJECT_1A); + set.insert(OBJECT_2A); + assert_eq!(set.len(), 1); + assert!(set.contains(&OBJECT_1A)); + assert!(!set.contains(&OBJECT_1B)); + assert!(set.contains(&OBJECT_2A)); + set.insert(OBJECT_1B); + assert_eq!(set.len(), 2); + assert!(set.contains(&OBJECT_1A)); + assert!(set.contains(&OBJECT_1B)); + assert!(set.contains(&OBJECT_2A)); +} diff --git a/data_model/derive/tests/filter.rs b/data_model/derive/tests/ui_pass/filter.rs similarity index 95% rename from data_model/derive/tests/filter.rs rename to data_model/derive/tests/ui_pass/filter.rs index 27e54a056f1..94dccc72e95 100644 --- a/data_model/derive/tests/filter.rs +++ b/data_model/derive/tests/ui_pass/filter.rs @@ -103,8 +103,4 @@ pub enum LayerEvent { Created(LayerId), } -#[test] -fn filter() { - // nothing much to test here... - // I guess we do test that it compiles -} +fn main() {}