Skip to content

Commit

Permalink
TryFromRpcValue: Allow a custom struct in enum
Browse files Browse the repository at this point in the history
  • Loading branch information
syyyr committed Jul 9, 2024
1 parent 6c87d24 commit 3d6538a
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 36 deletions.
73 changes: 44 additions & 29 deletions libshvproto-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,16 @@ fn is_option(ty: &syn::Type) -> bool {
matches!(data.args[0], syn::GenericArgument::Type(_))
}

fn is_type(ty: &syn::Type, to_match: &str) -> bool {
fn get_type(ty: &syn::Type) -> Option<String> {
let syn::Type::Path(typepath) = ty else {
return false
return None
};
if typepath.qself.is_some() {
return false
return None
}
let Some(segment) = typepath.path.segments.last() else {
return false;
};
let segment = typepath.path.segments.last()?;

segment.ident == to_match
Some(segment.ident.to_string())
}

#[proc_macro_derive(TryFromRpcValue, attributes(field_name))]
Expand Down Expand Up @@ -136,19 +134,18 @@ pub fn derive_from_rpcvalue(item: TokenStream) -> TokenStream {
}
},
syn::Data::Enum(syn::DataEnum { variants, .. }) => {
let mut match_arms_de = quote!{};
let mut match_arms_ser = quote!{};
let mut match_arms_de = vec![];
let mut match_arms_ser = quote!{};
let mut allowed_types = vec![];
let mut map_has_been_matched: Option<(proc_macro2::TokenStream, proc_macro2::TokenStream)> = None;
for variant in variants {
let variant_ident = &variant.ident;
let mut add_type_matcher = |should, dest_variant_type, block| {
if should {
allowed_types.push(quote!{stringify!(#dest_variant_type)});
let mut add_type_matcher = |match_arms: &mut Vec<proc_macro2::TokenStream>, rpcvalue_variant_type, block| {
allowed_types.push(quote!{stringify!(#rpcvalue_variant_type)});

match_arms_de.extend(quote!{
shvproto::Value::#dest_variant_type => Ok(<#struct_identifier>::#variant_ident #block),
});
}
match_arms.push(quote!{
shvproto::Value::#rpcvalue_variant_type => Ok(<#struct_identifier>::#variant_ident #block)
});
};
if let syn::Fields::Unnamed(variant_types) = &variant.fields {
if variant_types.unnamed.len() != 1 {
Expand All @@ -162,23 +159,41 @@ pub fn derive_from_rpcvalue(item: TokenStream) -> TokenStream {
let deref_code = quote!((*x));
let unbox_code = quote!((x.as_ref().clone()));

add_type_matcher(is_type(source_variant_type, "i64"), quote!{Int(x)}, deref_code.clone());
add_type_matcher(is_type(source_variant_type, "u64"), quote!{UInt(x)}, deref_code.clone());
add_type_matcher(is_type(source_variant_type, "f64"), quote!{Double(x)}, deref_code.clone());
add_type_matcher(is_type(source_variant_type, "bool"), quote!{Bool(x)}, deref_code.clone());
add_type_matcher(is_type(source_variant_type, "DateTime"), quote!{DateTime(x)}, deref_code.clone());
add_type_matcher(is_type(source_variant_type, "Decimal"), quote!{Decimal(x)}, deref_code.clone());
add_type_matcher(is_type(source_variant_type, "String"), quote!{String(x)}, unbox_code.clone());
add_type_matcher(is_type(source_variant_type, "Blob"), quote!{Blob(x)}, unbox_code.clone());
add_type_matcher(is_type(source_variant_type, "List"), quote!{List(x)}, unbox_code.clone());
add_type_matcher(is_type(source_variant_type, "Map"), quote!{Map(x)}, unbox_code.clone());
add_type_matcher(is_type(source_variant_type, "IMap"), quote!{IMap(x)}, unbox_code.clone());
if let Some(type_identifier) = get_type(source_variant_type) {
match type_identifier.as_ref() {
"i64" => add_type_matcher(&mut match_arms_de, quote!{Int(x)}, deref_code.clone()),
"u64" => add_type_matcher(&mut match_arms_de, quote!{UInt(x)}, deref_code.clone()),
"f64" => add_type_matcher(&mut match_arms_de, quote!{Double(x)}, deref_code.clone()),
"bool" => add_type_matcher(&mut match_arms_de, quote!{Bool(x)}, deref_code.clone()),
"DateTime" => add_type_matcher(&mut match_arms_de, quote!{DateTime(x)}, deref_code.clone()),
"Decimal" => add_type_matcher(&mut match_arms_de, quote!{Decimal(x)}, deref_code.clone()),
"String" => add_type_matcher(&mut match_arms_de, quote!{String(x)}, unbox_code.clone()),
"Blob" => add_type_matcher(&mut match_arms_de, quote!{Blob(x)}, unbox_code.clone()),
"List" => add_type_matcher(&mut match_arms_de, quote!{List(x)}, unbox_code.clone()),
"Map" => {
if let Some((matched_variant_ident, matched_variant_type)) = map_has_been_matched {
panic!("Can't match enum variant {}(Map), because a Map will already be matched as {}({})", quote!{#variant_ident}, quote!{#matched_variant_ident}, quote!{#matched_variant_type});
}
add_type_matcher(&mut match_arms_de, quote!{Map(x)}, unbox_code.clone());
map_has_been_matched = Some((quote!(#variant_ident), quote!{#source_variant_type}));
},
"IMap" => add_type_matcher(&mut match_arms_de, quote!{IMap(x)}, unbox_code.clone()),
_ => {
if let Some((matched_variant_ident, matched_variant_type)) = map_has_been_matched {
panic!("Can't match enum variant {}({}) as a Map, because a Map will already be matched as {}({})", quote!{#variant_ident}, quote!{#source_variant_type}, quote!{#matched_variant_ident}, quote!{#matched_variant_type});
}
add_type_matcher(&mut match_arms_de, quote! {Map(x)}, quote!((x.as_ref().clone().try_into()?)));
map_has_been_matched = Some((quote!(#variant_ident), quote!{#source_variant_type}));
}

}
}
}
if let syn::Fields::Unit = &variant.fields {
match_arms_ser.extend(quote!{
#struct_identifier::#variant_ident => RpcValue::null(),
});
add_type_matcher(true, quote! {Null}, quote!());
add_type_matcher(&mut match_arms_de, quote! {Null}, quote!());
}
}

Expand All @@ -187,7 +202,7 @@ pub fn derive_from_rpcvalue(item: TokenStream) -> TokenStream {
type Error = String;
fn try_from(value: shvproto::RpcValue) -> Result<Self, Self::Error> {
match value.value() {
#match_arms_de
#(#match_arms_de),*,
_ => Err("Couldn't deserialize into '".to_owned() + stringify!(#struct_identifier) + "' enum, allowed types: " + [#(#allowed_types),*].join("|").as_ref() + ", got: " + value.type_name())
}
}
Expand Down
28 changes: 21 additions & 7 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ mod test {
use libshvproto_macros::TryFromRpcValue;
use shvproto::RpcValue;

#[derive(Debug,PartialEq,TryFromRpcValue)]
struct EmptyStruct {
#[derive(Clone,Debug,PartialEq,TryFromRpcValue)]
pub struct EmptyStruct {
}

#[derive(Debug,PartialEq,TryFromRpcValue)]
struct OneFieldStruct {
#[derive(Clone,Debug,PartialEq,TryFromRpcValue)]
pub struct OneFieldStruct {
x: i32
}

#[derive(Debug,PartialEq,TryFromRpcValue)]
struct TestStruct {
pub struct TestStruct {
int_field: i32,
#[field_name = "myCustomFieldName"] int_field_with_custom_field_name: i32,
string_field: String,
Expand Down Expand Up @@ -49,6 +49,11 @@ mod test {
IMap(shvproto::rpcvalue::IMap),
}

#[derive(Clone,Debug,PartialEq,TryFromRpcValue)]
pub enum EnumWithUserStruct {
OneFieldStructVariant(OneFieldStruct),
IntVariant(i64)
}

#[test]
fn derive_struct() {
Expand All @@ -66,7 +71,7 @@ mod test {
"mapIntField" => [("aaa".to_string(), 111)].into_iter().collect::<BTreeMap<_,_>>(),
"imapField" => [(420, 111)].into_iter().collect::<BTreeMap<_,_>>(),
).into();

let y: TestStruct = x.clone().try_into().expect("Failed to parse");
assert_eq!(x, y.into());
}
Expand Down Expand Up @@ -100,7 +105,7 @@ mod test {
let z: AllVariants = y.try_into().expect("Failed to parse");
assert_eq!(x, z);
};

impl_test(AllVariants::Null);
impl_test(AllVariants::Int(123));
impl_test(AllVariants::UInt(465));
Expand All @@ -113,5 +118,14 @@ mod test {
impl_test(AllVariants::List(vec![RpcValue::from("some_value")]));
impl_test(AllVariants::Map(shvproto::make_map!("key" => 1234)));
impl_test(AllVariants::IMap([(420, 111.into())].into_iter().collect::<BTreeMap<_,_>>()));

let impl_test = |x: EnumWithUserStruct| {
let y: RpcValue = x.clone().into();
let z: EnumWithUserStruct = y.try_into().expect("Failed to parse");
assert_eq!(x, z);
};

impl_test(EnumWithUserStruct::OneFieldStructVariant(OneFieldStruct{x: 123}));
impl_test(EnumWithUserStruct::IntVariant(123));
}
}

0 comments on commit 3d6538a

Please sign in to comment.