Skip to content

Commit

Permalink
[refactor] #3882: Make derive(IdEqOrdHash) use darling, add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Nikita Strygin <[email protected]>
  • Loading branch information
DCNick3 committed Sep 26, 2023
1 parent 0eb7f77 commit f5446ed
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 68 deletions.
189 changes: 129 additions & 60 deletions data_model/derive/src/id.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,109 @@
#![allow(clippy::str_to_string, clippy::mixed_read_write_in_expression)]

use manyhow::{bail, Result};
use darling::{FromAttributes, FromDeriveInput, FromField};
use iroha_macro_utils::Emitter;
use manyhow::emit;
use proc_macro2::TokenStream;
use quote::quote;
use quote::{quote, ToTokens};
use syn2::parse_quote;

pub fn impl_id(input: &syn2::ItemStruct) -> Result<TokenStream> {
mod kw {
syn2::custom_keyword!(transparent);
}

enum IdAttr {
Missing,
Normal,
Transparent,
}

impl FromAttributes for IdAttr {
fn from_attributes(attrs: &[syn2::Attribute]) -> darling::Result<Self> {
let mut accumulator = darling::error::Accumulator::default();
let attrs = attrs
.iter()
.filter(|v| v.path().is_ident("id"))
.collect::<Vec<_>>();
let attr = match attrs.as_slice() {
[] => {
return accumulator.finish_with(IdAttr::Missing);
}
[attr] => attr,
[attr, ref tail @ ..] => {
accumulator.push(
darling::Error::custom("Only one `#[id]` attribute is allowed!").with_span(
&tail
.iter()
.map(syn2::spanned::Spanned::span)
.reduce(|a, b| a.join(b).unwrap())
.unwrap(),
),
);
attr
}
};

let result = match &attr.meta {
syn2::Meta::Path(_) => IdAttr::Normal,
syn2::Meta::List(list) if list.parse_args::<kw::transparent>().is_ok() => {
IdAttr::Transparent
}
_ => {
accumulator.push(
darling::Error::custom("Expected `#[id]` or `#[id(transparent)]`")
.with_span(&attr),
);
IdAttr::Normal
}
};

accumulator.finish_with(result)
}
}

#[derive(FromDeriveInput)]
#[darling(supports(struct_any))]
struct IdDeriveInput {
ident: syn2::Ident,
generics: syn2::Generics,
data: darling::ast::Data<darling::util::Ignored, IdField>,
}

struct IdField {
ident: Option<syn2::Ident>,
ty: syn2::Type,
id_attr: IdAttr,
}

impl FromField for IdField {
fn from_field(field: &syn2::Field) -> darling::Result<Self> {
let ident = field.ident.clone();
let ty = field.ty.clone();
let id_attr = IdAttr::from_attributes(&field.attrs)?;

Ok(Self { ident, ty, id_attr })
}
}

impl IdDeriveInput {
fn fields(&self) -> &darling::ast::Fields<IdField> {
match &self.data {
darling::ast::Data::Struct(fields) => fields,
_ => unreachable!(),
}
}
}

pub fn impl_id_eq_ord_hash(emitter: &mut Emitter, input: &syn2::DeriveInput) -> TokenStream {
let Some(input) = emitter.handle(IdDeriveInput::from_derive_input(input)) else {
return quote!();
};

let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let identifiable_derive = derive_identifiable(input)?;
let identifiable_derive = derive_identifiable(emitter, &input);

Ok(quote! {
quote! {
#identifiable_derive

impl #impl_generics ::core::cmp::PartialOrd for #name #ty_generics #where_clause where Self: Identifiable {
Expand Down Expand Up @@ -38,15 +131,15 @@ pub fn impl_id(input: &syn2::ItemStruct) -> Result<TokenStream> {
self.id().hash(state);
}
}
})
}
}

fn derive_identifiable(input: &syn2::ItemStruct) -> Result<TokenStream> {
fn derive_identifiable(emitter: &mut Emitter, input: &IdDeriveInput) -> TokenStream {
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let (id_type, id_expr) = get_id_type(input)?;
let (id_type, id_expr) = get_id_type(emitter, input);

Ok(quote! {
quote! {
impl #impl_generics Identifiable for #name #ty_generics #where_clause {
type Id = #id_type;

Expand All @@ -55,66 +148,42 @@ fn derive_identifiable(input: &syn2::ItemStruct) -> Result<TokenStream> {
#id_expr
}
}
})
}
}

fn get_id_type(input: &syn2::ItemStruct) -> Result<(TokenStream, TokenStream)> {
match &input.fields {
syn2::Fields::Named(fields) => {
for field in &fields.named {
let (field_name, field_ty) = (&field.ident, &field.ty);

if is_identifier(&field.attrs) {
return Ok((quote! {#field_ty}, quote! {&self.#field_name}));
}
if is_transparent(&field.attrs) {
return Ok((
quote! {<#field_ty as Identifiable>::Id},
quote! {Identifiable::id(&self.#field_name)},
));
}
fn get_id_type(emitter: &mut Emitter, input: &IdDeriveInput) -> (syn2::Type, syn2::Expr) {
for (field_index, IdField { ty, ident, id_attr }) in input.fields().iter().enumerate() {
let field_name = ident.as_ref().map_or_else(
|| syn2::Index::from(field_index).to_token_stream(),
ToTokens::to_token_stream,
);
match id_attr {
IdAttr::Normal => {
return (ty.clone(), parse_quote! {&self.#field_name});
}
}
syn2::Fields::Unnamed(fields) => {
for (i, field) in fields.unnamed.iter().enumerate() {
let (field_id, field_ty): (syn2::Index, _) = (i.into(), &field.ty);

if is_identifier(&field.attrs) {
return Ok((quote! {#field_ty}, quote! {&self.#field_id}));
}
if is_transparent(&field.attrs) {
return Ok((
quote! {<#field_ty as Identifiable>::Id},
quote! {Identifiable::id(&self.#field_id)},
));
}
IdAttr::Transparent => {
return (
parse_quote! {<#ty as Identifiable>::Id},
parse_quote! {Identifiable::id(&self.#field_name)},
);
}
IdAttr::Missing => {
// nothing here
}
}
syn2::Fields::Unit => {}
}

match &input.fields {
syn2::Fields::Named(named) => {
for field in &named.named {
let field_ty = &field.ty;

if field.ident.as_ref().expect("Field must be named") == "id" {
return Ok((quote! {#field_ty}, quote! {&self.id}));
}
}
for field in input.fields().iter() {
if field.ident.as_ref().is_some_and(|i| i == "id") {
return (field.ty.clone(), parse_quote! {&self.id});
}
syn2::Fields::Unnamed(_) | syn2::Fields::Unit => {}
}

bail!(input, "Identifier not found")
}

fn is_identifier(attrs: &[syn2::Attribute]) -> bool {
attrs.iter().any(|attr| attr == &parse_quote! {#[id]})
}
emit!(
emitter,
"Could not find the identifier field. Either mark it with `#[id]` or have it named `id`"
);

fn is_transparent(attrs: &[syn2::Attribute]) -> bool {
attrs
.iter()
.any(|attr| attr == &parse_quote! {#[id(transparent)]})
// return dummy types
(parse_quote! {()}, parse_quote! {()})
}
11 changes: 8 additions & 3 deletions data_model/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,15 @@ pub fn model_single(input: TokenStream) -> TokenStream {
///
#[manyhow]
#[proc_macro_derive(IdEqOrdHash, attributes(id, opaque))]
pub fn id_eq_ord_hash(input: TokenStream) -> Result<TokenStream> {
let input = syn2::parse2(input)?;
pub fn id_eq_ord_hash(input: TokenStream) -> TokenStream {
let mut emitter = Emitter::new();

id::impl_id(&input)
let Some(input) = emitter.handle(syn2::parse2(input)) else {
return emitter.finish_token_stream();
};

let result = id::impl_id_eq_ord_hash(&mut emitter, &input);
emitter.finish_token_stream_with(result)
}

/// [`Filter`] is used for code generation of `...Filter` structs and `...EventFilter` enums, as well as
Expand Down
117 changes: 117 additions & 0 deletions data_model/derive/tests/id_eq_ord_hash.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//! Basic tests for traits derived by [`IdEqOrdHash`] macro
use std::collections::BTreeSet;

use iroha_data_model_derive::IdEqOrdHash;

/// fake `Identifiable` trait
///
/// Doesn't require `Into<IdBox>` implementation
pub trait Identifiable: Ord + Eq {
/// Type of the entity identifier
type Id: Ord + Eq + core::hash::Hash;

/// Get reference to the type identifier
fn id(&self) -> &Self::Id;
}

#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
struct ObjectId(char);

#[derive(Debug, IdEqOrdHash)]
struct Object {
id: ObjectId,
#[allow(unused)]
data: i32,
}
#[derive(Debug, IdEqOrdHash)]
struct ObjectWithExplicitId {
#[id]
definitely_not_id: ObjectId,
#[allow(unused)]
data: i32,
}
#[derive(Debug, IdEqOrdHash)]
struct ObjectWithTransparentId {
#[id(transparent)] // delegate the id to `Object` type
definitely_not_id: Object,
#[allow(unused)]
data: i32,
}

// some objects to play with in tests
const ID_A: ObjectId = ObjectId('A');
const ID_B: ObjectId = ObjectId('B');
const OBJECT_1A: Object = Object { id: ID_A, data: 1 };
const OBJECT_1B: Object = Object { id: ID_B, data: 1 };
const OBJECT_2A: Object = Object { id: ID_A, data: 2 };
const EXPLICIT_OBJECT_1A: ObjectWithExplicitId = ObjectWithExplicitId {
definitely_not_id: ID_A,
data: 1,
};
const EXPLICIT_OBJECT_1B: ObjectWithExplicitId = ObjectWithExplicitId {
definitely_not_id: ID_B,
data: 1,
};
const EXPLICIT_OBJECT_2A: ObjectWithExplicitId = ObjectWithExplicitId {
definitely_not_id: ID_A,
data: 2,
};
const TRANSPARENT_OBJECT_1A: ObjectWithTransparentId = ObjectWithTransparentId {
definitely_not_id: OBJECT_1A,
data: 1,
};
const TRANSPARENT_OBJECT_1B: ObjectWithTransparentId = ObjectWithTransparentId {
definitely_not_id: OBJECT_1B,
data: 1,
};
const TRANSPARENT_OBJECT_2A: ObjectWithTransparentId = ObjectWithTransparentId {
definitely_not_id: OBJECT_2A,
data: 2,
};

#[test]
fn id() {
assert_eq!(OBJECT_1A.id(), &ID_A);
assert_eq!(OBJECT_1B.id(), &ID_B);
assert_eq!(EXPLICIT_OBJECT_1A.id(), &ID_A);
assert_eq!(EXPLICIT_OBJECT_1B.id(), &ID_B);
assert_eq!(TRANSPARENT_OBJECT_1A.id(), &ID_A);
assert_eq!(TRANSPARENT_OBJECT_1B.id(), &ID_B);
}

#[test]
fn id_eq() {
assert_eq!(OBJECT_1A, OBJECT_2A);
assert_ne!(OBJECT_1B, OBJECT_2A);
assert_eq!(EXPLICIT_OBJECT_1A, EXPLICIT_OBJECT_2A);
assert_ne!(EXPLICIT_OBJECT_1B, EXPLICIT_OBJECT_2A);
assert_eq!(TRANSPARENT_OBJECT_1A, TRANSPARENT_OBJECT_2A);
assert_ne!(TRANSPARENT_OBJECT_1B, TRANSPARENT_OBJECT_2A);
}

#[test]
fn id_ord() {
assert!(OBJECT_1A < OBJECT_1B);
assert!(OBJECT_1B > OBJECT_1A);
assert!(EXPLICIT_OBJECT_1A < EXPLICIT_OBJECT_1B);
assert!(EXPLICIT_OBJECT_1B > EXPLICIT_OBJECT_1A);
assert!(TRANSPARENT_OBJECT_1A < TRANSPARENT_OBJECT_1B);
assert!(TRANSPARENT_OBJECT_1B > TRANSPARENT_OBJECT_1A);
}

#[test]
fn id_hash() {
let mut set = BTreeSet::new();
set.insert(OBJECT_1A);
set.insert(OBJECT_2A);
assert_eq!(set.len(), 1);
assert!(set.contains(&OBJECT_1A));
assert!(!set.contains(&OBJECT_1B));
assert!(set.contains(&OBJECT_2A));
set.insert(OBJECT_1B);
assert_eq!(set.len(), 2);
assert!(set.contains(&OBJECT_1A));
assert!(set.contains(&OBJECT_1B));
assert!(set.contains(&OBJECT_2A));
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,4 @@ pub enum LayerEvent {
Created(LayerId),
}

#[test]
fn filter() {
// nothing much to test here...
// I guess we do test that it compiles
}
fn main() {}

0 comments on commit f5446ed

Please sign in to comment.