diff --git a/tests/ui/fail/loop_invariant_trait_self.rs b/tests/ui/fail/loop_invariant_trait_self.rs new file mode 100644 index 00000000..914a9214 --- /dev/null +++ b/tests/ui/fail/loop_invariant_trait_self.rs @@ -0,0 +1,59 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper THRUST_SOLVER_TIMEOUT_SECS=60 + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { + unimplemented!() +} + +#[thrust_macros::context] +trait Gauge { + #[thrust_macros::predicate] + fn invariant(x: i32) -> bool; + + fn update(&mut self) -> i32; + + #[thrust_macros::invariant_context] + fn run(&mut self) -> i32 { + let mut state = 0; + while rand() == 0 { + state = self.update(); + thrust_macros::invariant!(|state: i32| Self::invariant(state)); + } + state + } +} + +#[derive(PartialEq)] +struct Counter { + value: i32, +} + +impl thrust_models::Model for Counter { + type Ty = Counter; +} + +#[thrust_macros::context] +impl Gauge for Counter { + #[thrust_macros::predicate] + fn invariant(x: i32) -> bool { + "(>= x 0)"; true + } + + fn update(&mut self) -> i32 { + if self.value < 0 { + self.value *= -1; + } else { + self.value -= 1; + } + self.value + } +} + +fn main() { + let mut c = Counter { value: 0 }; + assert!(c.run() >= 0); +} diff --git a/tests/ui/pass/loop_invariant_trait_self.rs b/tests/ui/pass/loop_invariant_trait_self.rs new file mode 100644 index 00000000..5047710b --- /dev/null +++ b/tests/ui/pass/loop_invariant_trait_self.rs @@ -0,0 +1,59 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper THRUST_SOLVER_TIMEOUT_SECS=60 + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { + unimplemented!() +} + +#[thrust_macros::context] +trait Gauge { + #[thrust_macros::predicate] + fn invariant(x: i32) -> bool; + + fn update(&mut self) -> i32; + + #[thrust_macros::invariant_context] + fn run(&mut self) -> i32 { + let mut state = 0; + while rand() == 0 { + state = self.update(); + thrust_macros::invariant!(|state: i32| Self::invariant(state)); + } + state + } +} + +#[derive(PartialEq)] +struct Counter { + value: i32, +} + +impl thrust_models::Model for Counter { + type Ty = Counter; +} + +#[thrust_macros::context] +impl Gauge for Counter { + #[thrust_macros::predicate] + fn invariant(x: i32) -> bool { + "(>= x 0)"; true + } + + fn update(&mut self) -> i32 { + if self.value < 0 { + self.value *= -1; + } else { + self.value += 1; + } + self.value + } +} + +fn main() { + let mut c = Counter { value: 0 }; + assert!(c.run() >= 0); +} diff --git a/thrust-macros/src/invariant.rs b/thrust-macros/src/invariant.rs index a400da85..24caf754 100644 --- a/thrust-macros/src/invariant.rs +++ b/thrust-macros/src/invariant.rs @@ -165,20 +165,46 @@ fn expand_invariant( def_wheres.extend(type_lowering.model_where_predicates()); - // `Self` in a method context: rewrite it to a synthetic generic, then pass - // the real `Self` via turbofish (legal in expression position). + let mut body = closure.body.clone(); + + // `Self` in a method context: rewrite it to a synthetic generic everywhere + // it reaches the formula function — parameters, body, and the propagated + // where-clause predicates — then pass the real `Self` via turbofish (legal + // in expression position). if crate::tokens_contain_ident(&closure.to_token_stream(), "Self") { let synth: syn::Ident = format_ident!("__ThrustSelf"); + def_wheres.push(syn::parse_quote!(#synth: ?Sized)); + + let mut rewriter = SelfRewriter { synth: &synth }; for param in &mut fn_params { - SelfRewriter { synth: &synth }.visit_fn_arg_mut(param); + rewriter.visit_fn_arg_mut(param); + } + rewriter.visit_expr_mut(&mut body); + for pred in &mut def_wheres { + rewriter.visit_where_predicate_mut(pred); } def_params.push(quote!(#synth)); def_wheres.extend(type_lowering.model_where_predicates_for(&synth)); + // Mirror the host's implicit `Self: Trait` bound onto the synthetic + // generic so trait associated types (`Self::Item`) and predicates + // (`Self::step`) remain resolvable on it. + if let Some(FnOuterItem::ItemTrait(item_trait)) = + context.and_then(|context| context.outer.as_ref()) + { + let trait_ident = &item_trait.ident; + let (_, ty_generics, _) = item_trait.generics.split_for_impl(); + def_wheres.push(syn::parse_quote!(#synth: #trait_ident #ty_generics)); + } turbofish_args.push(quote!(Self)); + + // Rewriting `Self` to the synthetic generic can yield predicates that + // duplicate the synthetic generic's own `Model` bounds; drop the dups. + let mut seen = std::collections::HashSet::new(); + def_wheres.retain(|pred| seen.insert(pred.to_token_stream().to_string())); } let model_ty_params = type_lowering.lower_params(&fn_params); - let body = &closure.body; + let body = &body; let id = COUNTER.fetch_add(1, Ordering::Relaxed); let name = format_ident!("_thrust_invariant_{}", id); @@ -218,9 +244,9 @@ struct SelfRewriter<'a> { impl VisitMut for SelfRewriter<'_> { fn visit_path_mut(&mut self, path: &mut syn::Path) { syn::visit_mut::visit_path_mut(self, path); - if path.leading_colon.is_none() - && path.segments.len() == 1 - && path.segments[0].ident == "Self" + // Rewrite the leading `Self` of any path, covering both the bare type + // `Self` and qualified paths such as `Self::Item` / `Self::step`. + if path.leading_colon.is_none() && path.segments.first().is_some_and(|s| s.ident == "Self") { path.segments[0].ident = self.synth.clone(); }