Skip to content

Commit

Permalink
Respect trait bounds when deriving Display
Browse files Browse the repository at this point in the history
  • Loading branch information
greyblake committed Jun 9, 2024
1 parent c9df6d6 commit 537e73b
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 35 deletions.
16 changes: 14 additions & 2 deletions examples/any_generics/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
use nutype::nutype;
use std::borrow::Cow;

/// A wrapper around a vector that is guaranteed to be sorted.
#[nutype(
sanitize(with = |mut v| { v.sort(); v }),
derive(Debug)
)]
struct SortedVec<T: Ord>(Vec<T>);

/// A wrapper around a vector that is guaranteed to be non-empty.
#[nutype(
validate(predicate = |vec| !vec.is_empty()),
derive(Debug),
)]
struct NotEmpty<T>(Vec<T>);

/// An example with lifetimes
#[nutype(derive(
Debug,
Display,
Expand All @@ -33,10 +42,13 @@ struct Clarabelle<'a>(Cow<'a, str>);

fn main() {
{
let v = NotEmpty::new(vec![1, 2, 3]).unwrap();
assert_eq!(v.into_inner(), vec![1, 2, 3]);
let v = SortedVec::new(vec![3, 0, 2]);
assert_eq!(v.into_inner(), vec![0, 2, 3]);
}
{
let v = NotEmpty::new(vec![1, 2, 3]).unwrap();
assert_eq!(v.into_inner(), vec![1, 2, 3]);

let err = NotEmpty::<i32>::new(vec![]).unwrap_err();
assert_eq!(err, NotEmptyError::PredicateViolated);
}
Expand Down
17 changes: 11 additions & 6 deletions nutype_macros/src/common/gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ pub fn gen_impl_into_inner(
generics: &Generics,
inner_type: impl ToTokens,
) -> TokenStream {
// TODO: Consider stripping bounds only once instead of doing it every time?
let generics_without_bounds = strip_trait_bounds(generics);
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
quote! {
impl #generics #type_name #generics_without_bounds {
#[inline]
Expand All @@ -150,8 +149,14 @@ pub fn gen_impl_into_inner(
}
}

// TODO: Move to utils?
fn strip_trait_bounds(original: &Generics) -> Generics {
/// Remove trait bounds from generics.
///
/// Input:
/// <T: Display + Debug, U: Clone>
///
/// Output:
/// <T, U>
fn strip_trait_bounds_on_generics(original: &Generics) -> Generics {
let mut generics = original.clone();
for param in &mut generics.params {
if let syn::GenericParam::Type(syn::TypeParam { bounds, .. }) = param {
Expand Down Expand Up @@ -210,7 +215,7 @@ pub trait GenerateNewtype {
sanitizers: &[Self::Sanitizer],
validators: &[Self::Validator],
) -> TokenStream {
let generics_without_bounds = strip_trait_bounds(generics);
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
let fn_sanitize = Self::gen_fn_sanitize(inner_type, sanitizers);
let validation_error = Self::gen_validation_error_type(type_name, validators);
let error_type_name = gen_error_type_name(type_name);
Expand Down Expand Up @@ -251,7 +256,7 @@ pub trait GenerateNewtype {
inner_type: &Self::InnerType,
sanitizers: &[Self::Sanitizer],
) -> TokenStream {
let generics_without_bounds = strip_trait_bounds(generics);
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
let fn_sanitize = Self::gen_fn_sanitize(inner_type, sanitizers);

let (input_type, convert_raw_value_if_necessary) = if Self::NEW_CONVERT_INTO_INNER_TYPE {
Expand Down
8 changes: 6 additions & 2 deletions nutype_macros/src/common/gen/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::Generics;

use crate::common::models::{ErrorTypeName, InnerType, TypeName};
use crate::common::{
gen::strip_trait_bounds_on_generics,
models::{ErrorTypeName, InnerType, TypeName},
};

use super::parse_error::{gen_def_parse_error, gen_parse_error_name};

Expand Down Expand Up @@ -106,8 +109,9 @@ pub fn gen_impl_trait_deref(
}

pub fn gen_impl_trait_display(type_name: &TypeName, generics: &Generics) -> TokenStream {
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
quote! {
impl #generics ::core::fmt::Display for #type_name #generics {
impl #generics ::core::fmt::Display for #type_name #generics_without_bounds {
#[inline]
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
// A tiny wrapper function with trait boundary that improves error reporting.
Expand Down
144 changes: 119 additions & 25 deletions test_suite/tests/any.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use nutype::nutype;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::hash::Hash;
use test_suite::test_helpers::traits::*;

// Inner custom type, which is unknown to nutype
Expand Down Expand Up @@ -506,31 +508,123 @@ mod with_generics {
}
}

// TODO
// #[test]
// fn test_generic_with_boundaries_and_sanitize() {
// #[nutype(
// sanitize(with = |v| { v.sort(); v }),
// derive(Debug)
// )]
// struct SortedVec<T: Ord>(Vec<T>);

// {
// let vec = NonEmptyVec::new(vec![1, 2, 3]).unwrap();
// assert_eq!(vec.into_inner(), vec![1, 2, 3]);
// }

// {
// let vec = NonEmptyVec::new(vec![5]).unwrap();
// assert_eq!(vec.into_inner(), vec![5]);
// }

// {
// let vec: Vec<u8> = vec![];
// let err = NonEmptyVec::new(vec).unwrap_err();
// assert_eq!(err, NonEmptyVecError::PredicateViolated);
// }
// }
#[test]
fn test_generic_with_boundaries_and_sanitize() {
#[nutype(
sanitize(with = |mut v| { v.sort(); v }),
derive(Debug)
)]
struct SortedVec<T: Ord>(Vec<T>);

let sorted = SortedVec::new(vec![3, 1, 2]);
assert_eq!(sorted.into_inner(), vec![1, 2, 3]);
}

#[test]
fn test_generic_with_boundaries_and_many_derives() {
// The point of this test is to ensure that the generate code can be compiled at least
// with respect to the specified trait boundaries

// #[nutype(
// derive(Debug)
// )]
// struct Wrapper1<A: Hash + Eq + Clone, B: Ord>(HashMap<A, B>);
}

#[test]
fn test_generic_boundaries_debug() {
#[nutype(derive(Debug))]
struct WrapperDebug<T: Debug>(T);

let w = WrapperDebug::new(13);
assert_eq!(format!("{w:?}"), "WrapperDebug(13)");
}

#[test]
fn test_generic_boundaries_display() {
#[nutype(derive(Debug, Display))]
struct WrapperDisplay<T: Debug + Display>(T);

let number = WrapperDisplay::new(5);
assert_eq!(number.to_string(), "5");

let b = WrapperDisplay::new(true);
assert_eq!(b.to_string(), "true");
}

#[test]
fn test_generic_boundaries_clone() {
// TODO
}

#[test]
fn test_generic_boundaries_copy() {
// TODO
}

#[test]
fn test_generic_boundaries_partial_eq() {
// TODO
}

#[test]
fn test_generic_boundaries_eq() {
// TODO
}

#[test]
fn test_generic_boundaries_partial_ord() {
// TODO
}

#[test]
fn test_generic_boundaries_ord() {
// TODO
}

#[test]
fn test_generic_boundaries_hash() {
// TODO
}

#[test]
fn test_generic_boundaries_serialize() {
// TODO
}

#[test]
fn test_generic_boundaries_deserialize() {
// TODO
}

#[test]
fn test_generic_boundaries_from_str() {
// TODO
}

#[test]
fn test_generic_boundaries_arbitrary() {
// TODO
}

#[test]
fn test_generic_with_boundaries_and_sanitize_and_validate() {
#[nutype(
validate(predicate = |v| !v.is_empty()),
sanitize(with = |mut v| { v.sort(); v }),
derive(Debug)
)]
struct NonEmptySortedVec<T: Ord>(Vec<T>);

{
let err = NonEmptySortedVec::<i32>::new(vec![]).unwrap_err();
assert_eq!(err, NonEmptySortedVecError::PredicateViolated);
}
{
let vec = NonEmptySortedVec::new(vec![3, 1, 2]).unwrap();
assert_eq!(vec.into_inner(), vec![1, 2, 3]);
}
}

#[test]
fn test_generic_with_lifetime_cow() {
Expand Down

0 comments on commit 537e73b

Please sign in to comment.