Skip to content

Commit

Permalink
TryFromRpcValue: Support enums
Browse files Browse the repository at this point in the history
  • Loading branch information
syyyr committed Jul 8, 2024
1 parent 46c1063 commit 6b06125
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 1 deletion.
91 changes: 90 additions & 1 deletion libshvproto-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use core::panic;

use proc_macro::TokenStream;
use quote::quote;

fn is_option(ty: &syn::Type) -> bool {
let syn::Type::Path(typepath) = ty else {
return false
Expand All @@ -26,6 +27,19 @@ fn is_option(ty: &syn::Type) -> bool {
matches!(data.args[0], syn::GenericArgument::Type(_))
}

fn is_type(ty: &syn::Type, to_match: &str) -> bool {
let syn::Type::Path(typepath) = ty else {
return false
};
if typepath.qself.is_some() {
return false
}
let Some(segment) = &typepath.path.segments.last() else {
return false;
};

segment.ident == to_match
}

#[proc_macro_derive(TryFromRpcValue, attributes(field_name))]
pub fn derive_from_rpcvalue(item: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -120,7 +134,82 @@ pub fn derive_from_rpcvalue(item: TokenStream) -> TokenStream {
}
}
}
},
syn::Data::Enum(syn::DataEnum { variants, .. }) => {
let mut try_froms = quote!{};
let mut match_arms = quote!{};
let mut expected_types = quote!{""};
for variant in variants {
let variant_ident = &variant.ident;
let mut add_type_matcher = |should, dest_variant_type, block| {
if should {
expected_types.extend(quote! {+ " " + stringify!(#dest_variant_type)});

try_froms.extend(quote!{
if let shvproto::Value::#dest_variant_type = value.value() {
return Ok(<#struct_identifier>::#variant_ident #block);
}
});
}
};
if let syn::Fields::Unnamed(variant_types) = &variant.fields {
if variant_types.unnamed.len() != 1 {
panic!("jde jenom jeden typ" );
}
match_arms.extend(quote!{
#struct_identifier::#variant_ident(val) => RpcValue::from(val),
});

let source_variant_type = &variant_types.unnamed.first().expect("variant_type").ty;
let deref_code = quote!((*x));
let unbox_code = quote!((x.as_ref().clone()));

add_type_matcher(is_type(source_variant_type, "i64"), quote!{Int(x)}, (deref_code).clone());
add_type_matcher(is_type(source_variant_type, "u64"), quote!{UInt(x)}, (deref_code).clone());
add_type_matcher(is_type(source_variant_type, "f64"), quote!{Double(x)}, (deref_code).clone());
add_type_matcher(is_type(source_variant_type, "bool"), quote!{Bool(x)}, (deref_code).clone());
add_type_matcher(is_type(source_variant_type, "DateTime"), quote!{DateTime(x)}, (deref_code).clone());
add_type_matcher(is_type(source_variant_type, "Decimal"), quote!{Decimal(x)}, (deref_code).clone());
add_type_matcher(is_type(source_variant_type, "String"), quote!{String(x)}, (unbox_code).clone());
add_type_matcher(is_type(source_variant_type, "Blob"), quote!{Blob(x)}, (unbox_code).clone());
add_type_matcher(is_type(source_variant_type, "List"), quote!{List(x)}, (unbox_code).clone());
add_type_matcher(is_type(source_variant_type, "Map"), quote!{Map(x)}, (unbox_code).clone());
add_type_matcher(is_type(source_variant_type, "IMap"), quote!{IMap(x)}, (unbox_code).clone());
}
if let syn::Fields::Unit = &variant.fields {
match_arms.extend(quote!{
#struct_identifier::#variant_ident => RpcValue::null(),
});
add_type_matcher(true, quote! {Null}, quote!());
}
}

quote!{
impl TryFrom<shvproto::RpcValue> for #struct_identifier {
type Error = String;
fn try_from(value: shvproto::RpcValue) -> Result<Self, Self::Error> {
#try_froms

Err("Couldn't deserialize ".to_owned() + stringify!(#struct_identifier) + ", expected types: " + #expected_types + ", got: " + value.type_name())
}
}

impl TryFrom<&shvproto::RpcValue> for #struct_identifier {
type Error = String;
fn try_from(value: &shvproto::RpcValue) -> Result<Self, Self::Error> {
value.clone().try_into()
}
}

impl From<#struct_identifier> for shvproto::RpcValue {
fn from(value: #struct_identifier) -> Self {
match value {
#match_arms
}
}
}
}
}
_ => panic!("This macro can only be used on a struct.")
_ => panic!("This macro can only be used on a struct or a enum.")
}.into()
}
39 changes: 39 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ mod test {
optional_int_field: Option<i32>
}

#[derive(Clone,Debug,PartialEq,TryFromRpcValue)]
pub enum AllVariants {
Null,
Int(i64),
UInt(u64),
Double(f64),
Bool(bool),
DateTime(shvproto::datetime::DateTime),
Decimal(shvproto::decimal::Decimal),
String(String),
Blob(shvproto::Blob),
List(shvproto::List),
Map(shvproto::Map),
IMap(shvproto::rpcvalue::IMap),
}


#[test]
fn derive_struct() {
let x: RpcValue = shvproto::make_map!(
Expand Down Expand Up @@ -75,4 +92,26 @@ mod test {
optional_int_field: Some(59)
});
}

#[test]
fn enum_field() {
let impl_test = |x: AllVariants| {
let y: RpcValue = x.clone().into();
let z: AllVariants = y.try_into().expect("Failed to parse");
assert_eq!(x, z);
};

impl_test(AllVariants::Null);
impl_test(AllVariants::Int(123));
impl_test(AllVariants::UInt(465));
impl_test(AllVariants::Double(123.0));
impl_test(AllVariants::Bool(true));
impl_test(AllVariants::DateTime(shvproto::DateTime::now()));
impl_test(AllVariants::Decimal(shvproto::Decimal::new(1234, 2)));
impl_test(AllVariants::String("Some string".to_owned()));
impl_test(AllVariants::Blob(vec![1, 2, 3]));
impl_test(AllVariants::List(vec![RpcValue::from("some_value")]));
impl_test(AllVariants::Map(shvproto::make_map!("key" => 1234)));
impl_test(AllVariants::IMap([(420, 111.into())].into_iter().collect::<BTreeMap<_,_>>()));
}
}

0 comments on commit 6b06125

Please sign in to comment.