From 270cc62257686e13b0c20249899d71c2c2b98c04 Mon Sep 17 00:00:00 2001 From: yukang Date: Tue, 7 Apr 2026 17:17:06 +0800 Subject: [PATCH] refactor: simplify fn pointer cast suggestion logic --- .../traits/fulfillment_errors.rs | 55 +++-------------- .../src/error_reporting/traits/suggestions.rs | 59 ++++++++++++++++++- 2 files changed, 66 insertions(+), 48 deletions(-) diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs index 43ab4a64fbedc..04c9edc25d172 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs @@ -39,9 +39,7 @@ use rustc_span::{BytePos, DUMMY_SP, STDLIB_STABLE_CRATES, Span, Symbol, sym}; use tracing::{debug, instrument}; use super::suggestions::get_explanation_based_on_obligation; -use super::{ - ArgKind, CandidateSimilarity, FindExprBySpan, GetSafeTransmuteErrorAndReason, ImplCandidate, -}; +use super::{ArgKind, CandidateSimilarity, GetSafeTransmuteErrorAndReason, ImplCandidate}; use crate::error_reporting::TypeErrCtxt; use crate::error_reporting::infer::TyCategory; use crate::error_reporting::traits::report_dyn_incompatibility; @@ -452,50 +450,13 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> { self.suggest_dereferencing_index(&obligation, &mut err, leaf_trait_predicate); suggested |= self.suggest_dereferences(&obligation, &mut err, leaf_trait_predicate); suggested |= self.suggest_fn_call(&obligation, &mut err, leaf_trait_predicate); - let impl_candidates = self.find_similar_impl_candidates(leaf_trait_predicate); - suggested = if let &[cand] = &impl_candidates[..] { - let cand = cand.trait_ref; - if let (ty::FnPtr(..), ty::FnDef(..)) = - (cand.self_ty().kind(), main_trait_predicate.self_ty().skip_binder().kind()) - { - // Wrap method receivers and `&`-references in parens - let suggestion = if self.tcx.sess.source_map().span_followed_by(span, ".").is_some() { - vec![ - (span.shrink_to_lo(), format!("(")), - (span.shrink_to_hi(), format!(" as {})", cand.self_ty())), - ] - } else if let Some(body) = self.tcx.hir_maybe_body_owned_by(obligation.cause.body_id) { - let mut expr_finder = FindExprBySpan::new(span, self.tcx); - expr_finder.visit_expr(body.value); - if let Some(expr) = expr_finder.result && - let hir::ExprKind::AddrOf(_, _, expr) = expr.kind { - vec![ - (expr.span.shrink_to_lo(), format!("(")), - (expr.span.shrink_to_hi(), format!(" as {})", cand.self_ty())), - ] - } else { - vec![(span.shrink_to_hi(), format!(" as {}", cand.self_ty()))] - } - } else { - vec![(span.shrink_to_hi(), format!(" as {}", cand.self_ty()))] - }; - let trait_ = self.tcx.short_string(cand.print_trait_sugared(), err.long_ty_path()); - let ty = self.tcx.short_string(cand.self_ty(), err.long_ty_path()); - err.multipart_suggestion( - format!( - "the trait `{trait_}` is implemented for fn pointer \ - `{ty}`, try casting using `as`", - ), - suggestion, - Applicability::MaybeIncorrect, - ); - true - } else { - false - } - } else { - false - } || suggested; + suggested |= self.suggest_cast_to_fn_pointer( + &obligation, + &mut err, + leaf_trait_predicate, + main_trait_predicate, + span, + ); suggested |= self.suggest_remove_reference(&obligation, &mut err, leaf_trait_predicate); suggested |= self.suggest_semicolon_removal( diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs index 62f6b87c9e98c..805a00d198606 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs @@ -29,7 +29,8 @@ use rustc_middle::ty::adjustment::{Adjust, DerefAdjustKind}; use rustc_middle::ty::error::TypeError; use rustc_middle::ty::print::{ PrintPolyTraitPredicateExt as _, PrintPolyTraitRefExt, PrintTraitPredicateExt as _, - with_forced_trimmed_paths, with_no_trimmed_paths, with_types_for_suggestion, + PrintTraitRefExt as _, with_forced_trimmed_paths, with_no_trimmed_paths, + with_types_for_suggestion, }; use rustc_middle::ty::{ self, AdtKind, GenericArgs, InferTy, IsSuggestable, Ty, TyCtxt, TypeFoldable, TypeFolder, @@ -1142,6 +1143,62 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> { true } + pub(super) fn suggest_cast_to_fn_pointer( + &self, + obligation: &PredicateObligation<'tcx>, + err: &mut Diag<'_>, + leaf_trait_predicate: ty::PolyTraitPredicate<'tcx>, + main_trait_predicate: ty::PolyTraitPredicate<'tcx>, + span: Span, + ) -> bool { + let &[candidate] = &self.find_similar_impl_candidates(leaf_trait_predicate)[..] else { + return false; + }; + let candidate = candidate.trait_ref; + + if !matches!( + (candidate.self_ty().kind(), main_trait_predicate.self_ty().skip_binder().kind(),), + (ty::FnPtr(..), ty::FnDef(..)) + ) { + return false; + } + + let parenthesized_cast = |span: Span| { + vec![ + (span.shrink_to_lo(), "(".to_string()), + (span.shrink_to_hi(), format!(" as {})", candidate.self_ty())), + ] + }; + // Wrap method receivers and `&`-references in parens. + let suggestion = if self.tcx.sess.source_map().span_followed_by(span, ".").is_some() { + parenthesized_cast(span) + } else if let Some(body) = self.tcx.hir_maybe_body_owned_by(obligation.cause.body_id) { + let mut expr_finder = FindExprBySpan::new(span, self.tcx); + expr_finder.visit_expr(body.value); + if let Some(expr) = expr_finder.result + && let hir::ExprKind::AddrOf(_, _, expr) = expr.kind + { + parenthesized_cast(expr.span) + } else { + vec![(span.shrink_to_hi(), format!(" as {}", candidate.self_ty()))] + } + } else { + vec![(span.shrink_to_hi(), format!(" as {}", candidate.self_ty()))] + }; + + let trait_ = self.tcx.short_string(candidate.print_trait_sugared(), err.long_ty_path()); + let self_ty = self.tcx.short_string(candidate.self_ty(), err.long_ty_path()); + err.multipart_suggestion( + format!( + "the trait `{trait_}` is implemented for fn pointer \ + `{self_ty}`, try casting using `as`", + ), + suggestion, + Applicability::MaybeIncorrect, + ); + true + } + pub(super) fn check_for_binding_assigned_block_without_tail_expression( &self, obligation: &PredicateObligation<'tcx>,