diff --git a/validator/src/types.rs b/validator/src/types.rs index 9bf49da6..9bd06bc8 100644 --- a/validator/src/types.rs +++ b/validator/src/types.rs @@ -87,6 +87,19 @@ impl ValidationErrors { } } + pub fn merge_self_flatten( + &mut self, + child: Result<(), ValidationErrors>, + ) -> &mut ValidationErrors { + match child { + Ok(()) => self, + Err(errors) => { + self.0.extend(errors.0); + self + } + } + } + /// Returns the combined outcome of a struct's validation result along with the nested /// validation result for one of its fields. pub fn merge( diff --git a/validator_derive/src/case.rs b/validator_derive/src/case.rs new file mode 100644 index 00000000..417405ef --- /dev/null +++ b/validator_derive/src/case.rs @@ -0,0 +1,66 @@ +// Adapted from serde's code: +// https://github.com/serde-rs/serde/blob/d1790205/serde_derive/src/internals/case.rs + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum RenameRule { + #[default] + None, + LowerCase, + UpperCase, + PascalCase, + CamelCase, + SnakeCase, + ScreamingSnakeCase, + KebabCase, + ScreamingKebabCase, +} + +impl RenameRule { + pub fn parse(s: syn::LitStr) -> Option { + match &*s.value() { + "lowercase" => Some(Self::LowerCase), + "UPPERCASE" => Some(Self::UpperCase), + "PascalCase" => Some(Self::PascalCase), + "camelCase" => Some(Self::CamelCase), + "snake_case" => Some(Self::SnakeCase), + "SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase), + "kebab-case" => Some(Self::KebabCase), + "SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase), + _ => None, + } + } + + pub fn apply(&self, orig: &str) -> String { + match self { + Self::None | Self::LowerCase | Self::SnakeCase => orig.to_owned(), + Self::ScreamingSnakeCase => orig.to_ascii_uppercase(), + Self::UpperCase => orig.to_ascii_lowercase(), + Self::PascalCase => pascal_case(orig, true), + Self::CamelCase => pascal_case(orig, false), + Self::KebabCase => orig.replace("_", "-"), + Self::ScreamingKebabCase => orig.to_ascii_uppercase().replace("_", "-"), + } + } + + pub fn is_some(&self) -> bool { + !matches!(self, Self::None) + } +} + +/// Converts a snake_case field name into PascalCase or camelCase. +/// `cap_start` specifies if the first character is capitalized (as in Pascal). +fn pascal_case(orig: &str, cap_start: bool) -> String { + let mut pascal = String::new(); + let mut capitalize = cap_start; + for ch in orig.chars() { + if ch == '_' { + capitalize = true; + } else if capitalize { + pascal.push(ch.to_ascii_uppercase()); + capitalize = false; + } else { + pascal.push(ch); + } + } + pascal +} diff --git a/validator_derive/src/lib.rs b/validator_derive/src/lib.rs index 44a896e2..fbb5298c 100644 --- a/validator_derive/src/lib.rs +++ b/validator_derive/src/lib.rs @@ -3,8 +3,12 @@ use darling::util::{Override, WithOriginal}; use darling::FromDeriveInput; use proc_macro_error2::{abort, proc_macro_error}; use quote::{quote, ToTokens}; -use syn::{parse_macro_input, DeriveInput, Field, GenericParam, Path, PathArguments}; +use syn::meta::ParseNestedMeta; +use syn::{ + parse_macro_input, DeriveInput, Field, GenericParam, LitStr, Path, PathArguments, Token, +}; +use case::RenameRule; use tokens::cards::credit_card_tokens; use tokens::contains::contains_tokens; use tokens::custom::custom_tokens; @@ -23,6 +27,7 @@ use tokens::url::url_tokens; use types::*; use utils::{quote_use_stmts, CrateName}; +mod case; mod tokens; mod types; mod utils; @@ -30,7 +35,10 @@ mod utils; impl ToTokens for ValidateField { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { let field_name = self.ident.clone().unwrap(); - let field_name_str = self.ident.clone().unwrap().to_string(); + let field_name_str = match &self.rename { + Some(rename) => rename.clone(), + None => self.ident.as_ref().unwrap().to_string(), + }; let type_name = self.ty.to_token_stream().to_string(); let is_number = NUMBER_TYPES.contains(&type_name); @@ -211,7 +219,7 @@ impl ToTokens for ValidateField { let nested = if let Some(n) = self.nested { if n { - wrapper_closure(nested_tokens(&actual_field, &field_name_str)) + wrapper_closure(nested_tokens(&actual_field, &field_name_str, self.flatten)) } else { quote!() } @@ -258,6 +266,11 @@ struct ValidationData { crate_name: CrateName, } +#[derive(Debug, Default, Clone)] +struct SerdeData { + rename_all: RenameRule, +} + impl ValidationData { fn validate(self) -> darling::Result { if let Some(context) = &self.context { @@ -280,8 +293,7 @@ impl ValidationData { } if let Data::Struct(fields) = &self.data { - let original_fields: Vec<&Field> = - fields.fields.iter().map(|f| &f.original).collect(); + let original_fields: Vec<&Field> = fields.fields.iter().map(|f| &f.original).collect(); for f in &fields.fields { f.parsed.validate(&self.ident, &original_fields, &f.original); } @@ -296,6 +308,10 @@ impl ValidationData { pub fn derive_validation(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input: DeriveInput = parse_macro_input!(input); + // parse struct-level serde attributes + // i.e. rename_all + let serde = parse_serde_container_attrs(&input); + // parse the input to the ValidationData struct defined above let validation_data = match ValidationData::from_derive_input(&input) { Ok(data) => data, @@ -325,7 +341,7 @@ pub fn derive_validation(input: proc_macro::TokenStream) -> proc_macro::TokenStr .unwrap() .fields .into_iter() - .map(|f| f.parsed) + .map(|f| parse_serde_attrs(f.parsed, f.original, &serde)) // skip fields with #[validate(skip)] attribute .filter(|f| if let Some(s) = f.skip { !s } else { true }) .map(|f| ValidateField { crate_name: crate_name.clone(), ..f }) @@ -423,3 +439,71 @@ pub fn derive_validation(input: proc_macro::TokenStream) -> proc_macro::TokenStr ) .into() } + +fn parse_serde_container_attrs(di: &DeriveInput) -> SerdeData { + let mut data = SerdeData::default(); + + for attr in &di.attrs { + if !attr.path().is_ident("serde") { + continue; + } + + let _ = attr.parse_nested_meta(|meta| { + if let Some(rename_all) = parse_serde_rename(&meta, "rename_all")? { + RenameRule::parse(rename_all).map(|rule| data.rename_all = rule); + } + Ok(()) + }); + } + + data +} + +fn parse_serde_attrs( + mut vf: ValidateField, + field: Field, + serde_struct: &SerdeData, +) -> ValidateField { + if serde_struct.rename_all.is_some() { + vf.rename = Some(serde_struct.rename_all.apply(&vf.ident.as_ref().unwrap().to_string())); + } + + for attr in field.attrs { + if !attr.path().is_ident("serde") { + continue; + } + + let _ = attr.parse_nested_meta(|meta| { + if let Some(rename) = parse_serde_rename(&meta, "rename")? { + vf.rename = Some(rename.value()); + } else if meta.path.is_ident("flatten") { + vf.flatten = true; + } + Ok(()) + }); + } + + vf +} + +fn parse_serde_rename(meta: &ParseNestedMeta, id: &str) -> syn::Result> { + if !meta.path.is_ident(id) { + return Ok(None); + } + + if meta.input.peek(Token![=]) { + let value = meta.value()?; + Ok(Some(value.parse()?)) + } else { + let mut res = None; + meta.parse_nested_meta(|meta| { + let value = meta.value()?; + let s: LitStr = value.parse()?; + if meta.path.is_ident("deserialize") { + res = Some(s); + } + Ok(()) + })?; + Ok(res) + } +} diff --git a/validator_derive/src/tokens/nested.rs b/validator_derive/src/tokens/nested.rs index f4840e4c..79a9417c 100644 --- a/validator_derive/src/tokens/nested.rs +++ b/validator_derive/src/tokens/nested.rs @@ -3,10 +3,17 @@ use quote::quote; pub fn nested_tokens( field_name: &proc_macro2::TokenStream, field_name_str: &str, + flatten: bool, ) -> proc_macro2::TokenStream { - quote! { - if let std::collections::hash_map::Entry::Vacant(entry) = errors.0.entry(::std::borrow::Cow::Borrowed(#field_name_str)) { - errors.merge_self(#field_name_str, (&#field_name).validate()); + if flatten { + quote! { + errors.merge_self_flatten((&#field_name).validate()); + } + } else { + quote! { + if let std::collections::hash_map::Entry::Vacant(entry) = errors.0.entry(::std::borrow::Cow::Borrowed(#field_name_str)) { + errors.merge_self(#field_name_str, (&#field_name).validate()); + } } } } diff --git a/validator_derive/src/types.rs b/validator_derive/src/types.rs index 54cbcf1b..743fc5a3 100644 --- a/validator_derive/src/types.rs +++ b/validator_derive/src/types.rs @@ -67,6 +67,10 @@ pub struct ValidateField { pub custom: Vec, pub skip: Option, pub nested: Option, + #[darling(skip)] + pub rename: Option, + #[darling(skip)] + pub flatten: bool, /// Placeholder for the crate name, filled in by the [`ValidationData`](crate::ValidationData) value. #[darling(skip)] pub crate_name: CrateName, @@ -153,12 +157,11 @@ impl ValidateField { pub fn number_options(&self) -> u8 { fn find_option(mut count: u8, ty: &syn::Type) -> u8 { if let syn::Type::Path(p) = ty { - let idents_of_path = - p.path.segments.iter().fold(String::new(), |mut acc, v| { - acc.push_str(&v.ident.to_string()); - acc.push('|'); - acc - }); + let idents_of_path = p.path.segments.iter().fold(String::new(), |mut acc, v| { + acc.push_str(&v.ident.to_string()); + acc.push('|'); + acc + }); if OPTIONS_TYPE.contains(&idents_of_path.as_str()) { count += 1; diff --git a/validator_derive_tests/tests/flatten.rs b/validator_derive_tests/tests/flatten.rs new file mode 100644 index 00000000..acfa7b3e --- /dev/null +++ b/validator_derive_tests/tests/flatten.rs @@ -0,0 +1,49 @@ +use serde::Deserialize; +use validator::Validate; + +#[test] +fn can_flatten_structs() { + #[derive(Deserialize, Validate)] + struct TestStruct { + #[validate(range(min = -5))] + field: i16, + #[validate(nested)] + inner_regular: InnerStruct, + #[serde(flatten)] + #[validate(nested)] + inner_merged: FlattenedStruct, + } + + #[derive(Deserialize, Validate)] + struct InnerStruct { + #[validate(length(max = 5))] + test_sample: String, + #[validate(range(max = 200))] + something: i64, + } + + #[derive(Deserialize, Validate)] + struct FlattenedStruct { + #[validate(length(max = 5))] + hello_world: String, + #[validate(range(max = 200))] + anything: i64, + } + + let s = TestStruct { + field: -10, + inner_regular: InnerStruct { test_sample: "abcdef".to_owned(), something: 500 }, + inner_merged: FlattenedStruct { hello_world: "abcdef".to_owned(), anything: 500 }, + }; + + let errs = s.validate().unwrap_err().0; + + assert!(errs.contains_key("field")); + assert!(errs.contains_key("inner_regular")); + assert!(errs.contains_key("hello_world")); + assert!(errs.contains_key("anything")); + + assert!(!errs.contains_key("inner_merged")); + assert!(!errs.contains_key("test_sample")); + assert!(!errs.contains_key("something")); +} diff --git a/validator_derive_tests/tests/rename.rs b/validator_derive_tests/tests/rename.rs new file mode 100644 index 00000000..bf4c6a64 --- /dev/null +++ b/validator_derive_tests/tests/rename.rs @@ -0,0 +1,140 @@ +use serde::Deserialize; +use validator::Validate; + +#[test] +fn renames_fields() { + #[derive(Deserialize, Validate)] + struct TestStruct { + #[serde(rename = "fieldNAME123")] + #[validate(range(min = -5))] + field_name: i16, + #[serde(default, skip_serializing, rename = "_SomeTest")] + #[validate(length(max = 5))] + some_test: String, + } + + let s = TestStruct { field_name: -10, some_test: "abcdef".to_owned() }; + + let err = s.validate().unwrap_err(); + let errs = err.field_errors(); + + assert!(errs.contains_key("fieldNAME123")); + assert!(errs.contains_key("_SomeTest")); + assert!(!errs.contains_key("field_name")); + assert!(!errs.contains_key("some_test")); +} + +#[test] +fn renames_fields_as_in_deserialize() { + #[derive(Deserialize, Validate)] + struct TestStruct { + #[serde(rename(serialize = "abc", deserialize = "fieldNAME123"))] + #[validate(range(min = -5))] + field_name: i16, + #[serde(default, skip_serializing, rename(deserialize = "_SomeTest"))] + #[validate(length(max = 5))] + some_test: String, + } + + let s = TestStruct { field_name: -10, some_test: "abcdef".to_owned() }; + + let err = s.validate().unwrap_err(); + let errs = err.field_errors(); + + assert!(errs.contains_key("fieldNAME123")); + assert!(errs.contains_key("_SomeTest")); + assert!(!errs.contains_key("abc")); + assert!(!errs.contains_key("field_name")); + assert!(!errs.contains_key("some_test")); +} + +#[test] +fn rename_all_camel_case_works() { + #[derive(Deserialize, Validate)] + #[serde(deny_unknown_fields, rename_all = "camelCase")] + struct TestStruct { + #[validate(range(min = -5))] + field_name: i16, + #[validate(length(max = 5))] + some_test_hello: String, + } + + let s = TestStruct { field_name: -10, some_test_hello: "abcdef".to_owned() }; + + let err = s.validate().unwrap_err(); + let errs = err.field_errors(); + + assert!(errs.contains_key("fieldName")); + assert!(errs.contains_key("someTestHello")); + assert!(!errs.contains_key("field_name")); + assert!(!errs.contains_key("some_test_hello")); +} + +#[test] +fn rename_all_camel_case_as_in_deserialize() { + #[derive(Deserialize, Validate)] + #[serde(deny_unknown_fields, rename_all(serialize = "kebab-case", deserialize = "camelCase"))] + struct TestStruct { + #[validate(range(min = -5))] + field_name: i16, + #[validate(length(max = 5))] + some_test_hello: String, + } + + let s = TestStruct { field_name: -10, some_test_hello: "abcdef".to_owned() }; + + let err = s.validate().unwrap_err(); + let errs = err.field_errors(); + + assert!(errs.contains_key("fieldName")); + assert!(errs.contains_key("someTestHello")); + assert!(!errs.contains_key("field-name")); + assert!(!errs.contains_key("some-test-hello")); + assert!(!errs.contains_key("field_name")); + assert!(!errs.contains_key("some_test_hello")); +} + +#[test] +fn rename_all_kebab_uppercase_works() { + #[derive(Deserialize, Validate)] + #[serde(rename_all = "SCREAMING-KEBAB-CASE")] + struct TestStruct { + #[validate(range(min = -5))] + field_name: i16, + #[validate(length(max = 5))] + some_test_hello: String, + } + + let s = TestStruct { field_name: -10, some_test_hello: "abcdef".to_owned() }; + + let err = s.validate().unwrap_err(); + let errs = err.field_errors(); + + assert!(errs.contains_key("FIELD-NAME")); + assert!(errs.contains_key("SOME-TEST-HELLO")); + assert!(!errs.contains_key("field_name")); + assert!(!errs.contains_key("some_test_hello")); +} + +#[test] +fn rename_all_pascal_with_custom() { + #[derive(Deserialize, Validate)] + #[serde(rename_all = "PascalCase")] + struct TestStruct { + #[validate(range(min = -5))] + field_name: i16, + #[serde(default, skip_serializing, rename = "_Some-test-123")] + #[validate(length(max = 5))] + some_test_hello: String, + } + + let s = TestStruct { field_name: -10, some_test_hello: "abcdef".to_owned() }; + + let err = s.validate().unwrap_err(); + let errs = err.field_errors(); + + assert!(errs.contains_key("FieldName")); + assert!(errs.contains_key("_Some-test-123")); + assert!(!errs.contains_key("field_name")); + assert!(!errs.contains_key("some_test_hello")); +}