Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ strum = { version = "0.27.2", features = ["derive"], default-features = fa

[dependencies]
quote = "1.0.41"
syn = { version = "2.0.106", features = ["full", "extra-traits"] }
syn = { version = "2.0.106", features = ["full", "extra-traits", "visit"] }
proc-macro2 = "1.0.101"
heck = "0.5.0"

Expand Down
128 changes: 110 additions & 18 deletions src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,120 @@ fn add_bound(generics: &mut Generics, bound: TypeParamBound) {
}
}

// Map a variant from an enum definition to how it would be used in a match
// E.g.
// * Foo -> Foo
// * Foo(Bar, Baz) -> Foo(var1, var2)
// * Foo { x: i32, y: i32 } -> Foo { x, y }
fn variant_to_unary_pat(variant: &Variant) -> TokenStream2 {
/// Generates the pattern and the corresponding expression with `.into()` calls.
/// Returns (pattern: TokenStream2, expression: TokenStream2)
fn variant_to_pat_and_into_expr(variant: &Variant) -> (TokenStream2, TokenStream2) {
let ident = &variant.ident;

match &variant.fields {
// --- 1. Named Fields (e.g., Variant { a, b }) ---
syn::Fields::Named(named) => {
// Pattern: Variant { var1, var2 }
let vars: Punctuated<Ident, Token![,]> = named
.named
.iter()
.flat_map(|it| it.ident.as_ref())
.cloned()
.collect();
let pattern = quote!(#ident { #vars });

let vars = vars.iter();
// Expression: Variant { var1: var1.into(), var2: var2.into() }
let expression = quote! {
#ident {
#(#vars: #vars.into()),*
}
};
(pattern, expression)
}

// --- 2. Unnamed Fields (e.g., Variant(var1, var2)) ---
syn::Fields::Unnamed(unnamed) => {
// Create identifiers for the variables (var0, var1, ...)
let vars: Punctuated<Ident, Token![,]> = unnamed
.unnamed
.iter()
.enumerate()
.map(|(idx, _)| format_ident!("var{idx}"))
.collect();

// Pattern: Variant(var0, var1, ...)
let pattern = quote!(#ident(#vars));

let vars = vars.iter();
// Expression: Variant(var0.into(), var1.into(), ...)
let expression = quote! {
#ident(#(#vars.into()),*)
};
(pattern, expression)
}

// --- 3. Unit Field (e.g., Variant) ---
syn::Fields::Unit => {
let pattern = quote!(#ident);
let expression = quote!(#ident);
(pattern, expression)
}
}
}

fn variant_to_pat_and_try_into_expr(
variant: &Variant,
error_type: &Ident,
) -> (TokenStream2, TokenStream2) {
let ident = &variant.ident;
let error_ident = error_type;

match &variant.fields {
// --- 1. Named Fields ---
syn::Fields::Named(named) => {
// Pattern: ParentEnum::Variant { var1, var2 }
let vars: Punctuated<Ident, Token![,]> = named.named.iter().map(snake_case).collect();
quote!(#ident{#vars})
let pattern = quote!(#ident { #vars });

// Expression: ParentEnum::Variant { var1: var1.try_into().map_err(|_| E)? }
let conversion_exprs = vars
.iter()
.map(|v| quote!(#v: #v.try_into().map_err(|_| #error_ident)?));

let expression = quote! {
#ident {
#(#conversion_exprs),*
}
};
(pattern, expression)
}

// --- 2. Unnamed Fields ---
syn::Fields::Unnamed(unnamed) => {
// Create identifiers for the variables (var0, var1, ...)
let vars: Punctuated<Ident, Token![,]> = unnamed
.unnamed
.iter()
.enumerate()
.map(|(idx, _)| format_ident!("var{idx}"))
.collect();
quote!(#ident(#vars))

// Pattern: ParentEnum::Variant(var0, var1, ...)
let pattern = quote!(#ident(#vars));

// Expression: ParentEnum::Variant(var0.try_into().map_err(|_| E)?, ...)
let conversion_exprs = vars
.iter()
.map(|v| quote!(#v.try_into().map_err(|_| #error_ident)?));

let expression = quote! {
#ident(#(#conversion_exprs),*)
};
(pattern, expression)
}

// --- 3. Unit Field ---
syn::Fields::Unit => {
let pattern = quote!(#ident);
let expression = quote!(#ident);
(pattern, expression)
}
syn::Fields::Unit => quote!(#ident),
}
}

Expand Down Expand Up @@ -121,15 +212,16 @@ impl Enum {
impl core::error::Error for #error {}
);

let pats: Vec<TokenStream2> = self.variants.iter().map(variant_to_unary_pat).collect();

let from_child_arms = pats
let into_pats = self.variants.iter().map(variant_to_pat_and_into_expr);
let try_into_pats = self
.variants
.iter()
.map(|pat| quote!(#child_ident::#pat => #parent_ident::#pat));
.map(|it| variant_to_pat_and_try_into_expr(it, &error));

let try_from_parent_arms = pats
.iter()
.map(|pat| quote!(#parent_ident::#pat => Ok(#child_ident::#pat)));
let from_child_arms = into_pats.map(|(a, b)| quote!(#child_ident::#a => #parent_ident::#b));

let try_from_parent_arms =
try_into_pats.map(|(a, b)| quote!(#parent_ident::#a => Ok(#child_ident::#b)));

let inherited_derives = self
.derives
Expand All @@ -138,7 +230,7 @@ impl Enum {

let vis = &parent.vis;

let (_child_impl, child_ty, _child_where) = child_generics.split_for_impl();
let (_child_impl, child_ty, child_where) = child_generics.split_for_impl();

let (parent_impl, parent_ty, parent_where) = parent.generics.split_for_impl();

Expand All @@ -149,7 +241,7 @@ impl Enum {
quote!(
#(#[ #attributes ])*
#(#child_attrs)*
#vis enum #child_ident #child_generics {
#vis enum #child_ident #child_generics #child_where {
#(#variants),*
}

Expand Down
170 changes: 77 additions & 93 deletions src/enum.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use crate::predicate::analyze_generics;
use crate::visitor::ParamVisitor;
use crate::{param::Param, Derive};
use alloc::{collections::BTreeSet, vec::Vec};
use proc_macro2::TokenStream;
use syn::{punctuated::Punctuated, Generics, Ident, Token, TypeParamBound, Variant};

use crate::{extractor::Extractor, iter::BoxedIter, param::Param, Derive};
use syn::visit::Visit;
use syn::{punctuated::Punctuated, Generics, Ident, Token, Variant, WherePredicate};

pub struct Enum {
pub ident: Ident,
Expand Down Expand Up @@ -34,99 +33,84 @@ impl Enum {
}

pub fn compute_generics(&mut self, parent_generics: &Generics) {
let generic_bounds: BTreeMap<Param, Vec<TypeParamBound>> = parent_generics
.type_params()
.map(|param| {
(
Param::Ident(param.ident.clone()),
param.bounds.iter().cloned().collect(),
)
})
.chain(parent_generics.lifetimes().map(|lifetime_def| {
(
Param::Lifetime(lifetime_def.lifetime.clone()),
lifetime_def
.bounds
.iter()
.cloned()
.map(TypeParamBound::Lifetime)
.collect(),
)
}))
.chain(
parent_generics
.where_clause
.iter()
.flat_map(|clause| &clause.predicates)
.flat_map(|pred| match pred {
syn::WherePredicate::Type(ty) => {
// We have to be a bit careful here. Imagine the bound
// <T as Add<U>>:: Foo
// We need to treat this as a bound on both `T` and on `U`.
let bounds: Vec<TypeParamBound> = ty.bounds.iter().cloned().collect();
ty.bounded_ty
.extract_idents()
.into_iter()
.map(move |ident| (Param::Ident(ident), bounds.clone()))
.boxed()
}
syn::WherePredicate::Lifetime(lt) => [(
Param::Lifetime(lt.lifetime.clone()),
lt.bounds
.iter()
.cloned()
.map(TypeParamBound::Lifetime)
.collect(),
)]
.into_iter()
.boxed(),
_ => panic!("Unsupported where predicate"),
}),
)
.collect();
// 1. Analyze constraints: Convert all inline bounds and where clauses
// into a list of PredicateDependency
let mut deps = analyze_generics(parent_generics);

// panic!("{generic_bounds:#?}");
// 2. Identify "Root" params: The generics explicitly used in the variants.
let mut visitor = ParamVisitor::new(parent_generics);
for variant in &self.variants {
visitor.visit_variant(variant);
}

let types = self
.variants
.iter()
.flat_map(|variant| match &variant.fields {
syn::Fields::Named(named) => named.named.iter().map(|field| &field.ty).collect(),
syn::Fields::Unnamed(unnamed) => {
unnamed.unnamed.iter().map(|field| &field.ty).collect()
}
syn::Fields::Unit => Vec::new(),
});
// Extract all of the lifetimes and idents we care about from the types.
let params = types.into_iter().flat_map(|ty| ty.extract_params());
let mut active_params: BTreeSet<Param> = visitor.found;
let mut active_predicates: Vec<WherePredicate> = Vec::new();

// The same generic may appear in multiple bounds, so we use a BTreeSet to dedup.
let relevant_params: BTreeSet<Param> = params
.flat_map(|param| param.find_relevant(&generic_bounds))
.collect();
// 3. Repeatedly iterate through dependencies. If a predicate mentions
// ANY active param, we must keep that predicate AND activate
// any other params it mentions.
let mut changed = true;
while changed {
changed = false;

self.generics = generics_subset(parent_generics, relevant_params.into_iter());
}
}
// We retain only the predicates we haven't matched yet.
deps.retain(|dep| {
// Check if this dependency touches any currently active param
let is_relevant = dep.used_params.iter().any(|p| active_params.contains(p));

/// Given a set of `Generics`, return the subset that we're interested in.
/// Expects `params` already includes all possible types/lifetimes we care
// about.
/// E.g. with generics `T: U, U, V`, this function should never be called with
/// just params of `T`; it would instead expect `T, U`.
/// In short: call `find_all_generics` first.
fn generics_subset(generics: &Generics, params: impl Iterator<Item = Param>) -> Generics {
let mut new = Generics::default();
if is_relevant {
// It is relevant: Keep the predicate
active_predicates.push(dep.predicate.clone());

for param in params {
let (generic_param, predicate) = param.find(generics);
if let Some(gp) = generic_param {
new.params.push(gp.clone());
// Activate all params used by this predicate
for p in &dep.used_params {
if active_params.insert(p.clone()) {
// If we added a NEW param, we must loop again
// to check for bounds dependent on this new param.
changed = true;
}
}
// Remove from `deps` so we don't process it again
return false;
}
true // Keep in `deps` for next pass
});
}
if let Some(pred) = predicate {
new.make_where_clause().predicates.push(pred.clone());
// 4. Construct the final Generics struct in-place
self.generics = Generics::default();

// A. Filter params and strip inline bounds
for param in &parent_generics.params {
let keep = match param {
syn::GenericParam::Type(t) => {
active_params.contains(&Param::Ident(t.ident.clone()))
}
syn::GenericParam::Lifetime(l) => {
active_params.contains(&Param::Lifetime(l.lifetime.clone()))
}
syn::GenericParam::Const(c) => {
active_params.contains(&Param::Ident(c.ident.clone()))
}
};

if keep {
let mut p = param.clone();
// CRITICAL: We clear inline bounds here because `analyze_generics`
// has already converted them into predicates. If we don't clear them,
// we will have duplicates (once in <> and once in where clause).
match &mut p {
syn::GenericParam::Type(t) => t.bounds.clear(),
syn::GenericParam::Lifetime(l) => l.bounds.clear(),
_ => {}
}
self.generics.params.push(p);
}
}
}

new
// B. Append the collected predicates to the where clause
if !active_predicates.is_empty() {
let where_clause = self.generics.make_where_clause();
where_clause.predicates.extend(active_predicates);
}
}
}
Loading