Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions validator/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ impl ValidationErrors {
}
}

pub fn merge_self_flatten(
&mut self,
child: Result<(), ValidationErrors>,
) -> &mut ValidationErrors {
match child {
Ok(()) => self,
Err(errors) => {
self.0.extend(errors.0);
self
}
}
}

/// Returns the combined outcome of a struct's validation result along with the nested
/// validation result for one of its fields.
pub fn merge(
Expand Down
66 changes: 66 additions & 0 deletions validator_derive/src/case.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Adapted from serde's code:
// https://github.com/serde-rs/serde/blob/d1790205/serde_derive/src/internals/case.rs

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RenameRule {
#[default]
None,
LowerCase,
UpperCase,
PascalCase,
CamelCase,
SnakeCase,
ScreamingSnakeCase,
KebabCase,
ScreamingKebabCase,
}

impl RenameRule {
pub fn parse(s: syn::LitStr) -> Option<Self> {
match &*s.value() {
"lowercase" => Some(Self::LowerCase),
"UPPERCASE" => Some(Self::UpperCase),
"PascalCase" => Some(Self::PascalCase),
"camelCase" => Some(Self::CamelCase),
"snake_case" => Some(Self::SnakeCase),
"SCREAMING_SNAKE_CASE" => Some(Self::ScreamingSnakeCase),
"kebab-case" => Some(Self::KebabCase),
"SCREAMING-KEBAB-CASE" => Some(Self::ScreamingKebabCase),
_ => None,
}
}

pub fn apply(&self, orig: &str) -> String {
match self {
Self::None | Self::LowerCase | Self::SnakeCase => orig.to_owned(),
Self::ScreamingSnakeCase => orig.to_ascii_uppercase(),
Self::UpperCase => orig.to_ascii_lowercase(),
Self::PascalCase => pascal_case(orig, true),
Self::CamelCase => pascal_case(orig, false),
Self::KebabCase => orig.replace("_", "-"),
Self::ScreamingKebabCase => orig.to_ascii_uppercase().replace("_", "-"),
}
}

pub fn is_some(&self) -> bool {
!matches!(self, Self::None)
}
}

/// Converts a snake_case field name into PascalCase or camelCase.
/// `cap_start` specifies if the first character is capitalized (as in Pascal).
fn pascal_case(orig: &str, cap_start: bool) -> String {
let mut pascal = String::new();
let mut capitalize = cap_start;
for ch in orig.chars() {
if ch == '_' {
capitalize = true;
} else if capitalize {
pascal.push(ch.to_ascii_uppercase());
capitalize = false;
} else {
pascal.push(ch);
}
}
pascal
}
96 changes: 90 additions & 6 deletions validator_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ use darling::util::{Override, WithOriginal};
use darling::FromDeriveInput;
use proc_macro_error2::{abort, proc_macro_error};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, DeriveInput, Field, GenericParam, Path, PathArguments};
use syn::meta::ParseNestedMeta;
use syn::{
parse_macro_input, DeriveInput, Field, GenericParam, LitStr, Path, PathArguments, Token,
};

use case::RenameRule;
use tokens::cards::credit_card_tokens;
use tokens::contains::contains_tokens;
use tokens::custom::custom_tokens;
Expand All @@ -23,14 +27,18 @@ use tokens::url::url_tokens;
use types::*;
use utils::{quote_use_stmts, CrateName};

mod case;
mod tokens;
mod types;
mod utils;

impl ToTokens for ValidateField {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let field_name = self.ident.clone().unwrap();
let field_name_str = self.ident.clone().unwrap().to_string();
let field_name_str = match &self.rename {
Some(rename) => rename.clone(),
None => self.ident.as_ref().unwrap().to_string(),
};

let type_name = self.ty.to_token_stream().to_string();
let is_number = NUMBER_TYPES.contains(&type_name);
Expand Down Expand Up @@ -211,7 +219,7 @@ impl ToTokens for ValidateField {

let nested = if let Some(n) = self.nested {
if n {
wrapper_closure(nested_tokens(&actual_field, &field_name_str))
wrapper_closure(nested_tokens(&actual_field, &field_name_str, self.flatten))
} else {
quote!()
}
Expand Down Expand Up @@ -258,6 +266,11 @@ struct ValidationData {
crate_name: CrateName,
}

#[derive(Debug, Default, Clone)]
struct SerdeData {
rename_all: RenameRule,
}

impl ValidationData {
fn validate(self) -> darling::Result<Self> {
if let Some(context) = &self.context {
Expand All @@ -280,8 +293,7 @@ impl ValidationData {
}

if let Data::Struct(fields) = &self.data {
let original_fields: Vec<&Field> =
fields.fields.iter().map(|f| &f.original).collect();
let original_fields: Vec<&Field> = fields.fields.iter().map(|f| &f.original).collect();
for f in &fields.fields {
f.parsed.validate(&self.ident, &original_fields, &f.original);
}
Expand All @@ -296,6 +308,10 @@ impl ValidationData {
pub fn derive_validation(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input: DeriveInput = parse_macro_input!(input);

// parse struct-level serde attributes
// i.e. rename_all
let serde = parse_serde_container_attrs(&input);

// parse the input to the ValidationData struct defined above
let validation_data = match ValidationData::from_derive_input(&input) {
Ok(data) => data,
Expand Down Expand Up @@ -325,7 +341,7 @@ pub fn derive_validation(input: proc_macro::TokenStream) -> proc_macro::TokenStr
.unwrap()
.fields
.into_iter()
.map(|f| f.parsed)
.map(|f| parse_serde_attrs(f.parsed, f.original, &serde))
// skip fields with #[validate(skip)] attribute
.filter(|f| if let Some(s) = f.skip { !s } else { true })
.map(|f| ValidateField { crate_name: crate_name.clone(), ..f })
Expand Down Expand Up @@ -423,3 +439,71 @@ pub fn derive_validation(input: proc_macro::TokenStream) -> proc_macro::TokenStr
)
.into()
}

fn parse_serde_container_attrs(di: &DeriveInput) -> SerdeData {
let mut data = SerdeData::default();

for attr in &di.attrs {
if !attr.path().is_ident("serde") {
continue;
}

let _ = attr.parse_nested_meta(|meta| {
if let Some(rename_all) = parse_serde_rename(&meta, "rename_all")? {
RenameRule::parse(rename_all).map(|rule| data.rename_all = rule);
}
Ok(())
});
}

data
}

fn parse_serde_attrs(
mut vf: ValidateField,
field: Field,
serde_struct: &SerdeData,
) -> ValidateField {
if serde_struct.rename_all.is_some() {
vf.rename = Some(serde_struct.rename_all.apply(&vf.ident.as_ref().unwrap().to_string()));
}

for attr in field.attrs {
if !attr.path().is_ident("serde") {
continue;
}

let _ = attr.parse_nested_meta(|meta| {
if let Some(rename) = parse_serde_rename(&meta, "rename")? {
vf.rename = Some(rename.value());
} else if meta.path.is_ident("flatten") {
vf.flatten = true;
}
Ok(())
});
}

vf
}

fn parse_serde_rename(meta: &ParseNestedMeta, id: &str) -> syn::Result<Option<LitStr>> {
if !meta.path.is_ident(id) {
return Ok(None);
}

if meta.input.peek(Token![=]) {
let value = meta.value()?;
Ok(Some(value.parse()?))
} else {
let mut res = None;
meta.parse_nested_meta(|meta| {
let value = meta.value()?;
let s: LitStr = value.parse()?;
Copy link
Author

@DarkCat09 DarkCat09 Jan 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parsing = & LitStr tokens is required no matter whether we get deserialize ident or not, because otherwise we leave the parser at the same position and don't get valid next attribute tokens on the next loop iteration

if meta.path.is_ident("deserialize") {
res = Some(s);
}
Ok(())
})?;
Ok(res)
}
}
13 changes: 10 additions & 3 deletions validator_derive/src/tokens/nested.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@ use quote::quote;
pub fn nested_tokens(
field_name: &proc_macro2::TokenStream,
field_name_str: &str,
flatten: bool,
) -> proc_macro2::TokenStream {
quote! {
if let std::collections::hash_map::Entry::Vacant(entry) = errors.0.entry(::std::borrow::Cow::Borrowed(#field_name_str)) {
errors.merge_self(#field_name_str, (&#field_name).validate());
if flatten {
quote! {
errors.merge_self_flatten((&#field_name).validate());
}
} else {
quote! {
if let std::collections::hash_map::Entry::Vacant(entry) = errors.0.entry(::std::borrow::Cow::Borrowed(#field_name_str)) {
errors.merge_self(#field_name_str, (&#field_name).validate());
}
}
}
}
15 changes: 9 additions & 6 deletions validator_derive/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ pub struct ValidateField {
pub custom: Vec<Custom>,
pub skip: Option<bool>,
pub nested: Option<bool>,
#[darling(skip)]
pub rename: Option<String>,
#[darling(skip)]
pub flatten: bool,
/// Placeholder for the crate name, filled in by the [`ValidationData`](crate::ValidationData) value.
#[darling(skip)]
pub crate_name: CrateName,
Expand Down Expand Up @@ -153,12 +157,11 @@ impl ValidateField {
pub fn number_options(&self) -> u8 {
fn find_option(mut count: u8, ty: &syn::Type) -> u8 {
if let syn::Type::Path(p) = ty {
let idents_of_path =
p.path.segments.iter().fold(String::new(), |mut acc, v| {
acc.push_str(&v.ident.to_string());
acc.push('|');
acc
});
let idents_of_path = p.path.segments.iter().fold(String::new(), |mut acc, v| {
acc.push_str(&v.ident.to_string());
acc.push('|');
acc
});

if OPTIONS_TYPE.contains(&idents_of_path.as_str()) {
count += 1;
Expand Down
49 changes: 49 additions & 0 deletions validator_derive_tests/tests/flatten.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use serde::Deserialize;
use validator::Validate;

#[test]
fn can_flatten_structs() {
#[derive(Deserialize, Validate)]
struct TestStruct {
#[validate(range(min = -5))]
field: i16,
#[validate(nested)]
inner_regular: InnerStruct,
#[serde(flatten)]
#[validate(nested)]
inner_merged: FlattenedStruct,
}

#[derive(Deserialize, Validate)]
struct InnerStruct {
#[validate(length(max = 5))]
test_sample: String,
#[validate(range(max = 200))]
something: i64,
}

#[derive(Deserialize, Validate)]
struct FlattenedStruct {
#[validate(length(max = 5))]
hello_world: String,
#[validate(range(max = 200))]
anything: i64,
}

let s = TestStruct {
field: -10,
inner_regular: InnerStruct { test_sample: "abcdef".to_owned(), something: 500 },
inner_merged: FlattenedStruct { hello_world: "abcdef".to_owned(), anything: 500 },
};

let errs = s.validate().unwrap_err().0;

assert!(errs.contains_key("field"));
assert!(errs.contains_key("inner_regular"));
assert!(errs.contains_key("hello_world"));
assert!(errs.contains_key("anything"));

assert!(!errs.contains_key("inner_merged"));
assert!(!errs.contains_key("test_sample"));
assert!(!errs.contains_key("something"));
}
Loading