Skip to content

Commit ad7c8f0

Browse files
Port rustc_autodiff to the attribute parsers
1 parent 9719610 commit ad7c8f0

File tree

12 files changed

+141
-151
lines changed

12 files changed

+141
-151
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
use std::str::FromStr;
2+
3+
use rustc_ast::LitKind;
4+
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
5+
use rustc_feature::{AttributeTemplate, template};
6+
use rustc_hir::attrs::{AttributeKind, RustcAutodiff};
7+
use rustc_hir::{MethodKind, Target};
8+
use rustc_span::{Symbol, sym};
9+
use thin_vec::ThinVec;
10+
11+
use crate::attributes::prelude::Allow;
12+
use crate::attributes::{AttributeOrder, OnDuplicate, SingleAttributeParser};
13+
use crate::context::{AcceptContext, Stage};
14+
use crate::parser::{ArgParser, MetaItemOrLitParser};
15+
use crate::target_checking::AllowedTargets;
16+
17+
pub(crate) struct RustcAutodiffParser;
18+
19+
impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
20+
const PATH: &[Symbol] = &[sym::rustc_autodiff];
21+
const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepInnermost;
22+
const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Error;
23+
const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[
24+
Allow(Target::Fn),
25+
Allow(Target::Method(MethodKind::Inherent)),
26+
Allow(Target::Method(MethodKind::Trait { body: true })),
27+
Allow(Target::Method(MethodKind::TraitImpl)),
28+
]);
29+
const TEMPLATE: AttributeTemplate = template!(
30+
List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"],
31+
"https://doc.rust-lang.org/std/autodiff/index.html"
32+
);
33+
34+
fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option<AttributeKind> {
35+
let list = match args {
36+
ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)),
37+
ArgParser::List(list) => list,
38+
ArgParser::NameValue(_) => {
39+
cx.expected_list_or_no_args(cx.attr_span);
40+
return None;
41+
}
42+
};
43+
44+
let mut items = list.mixed().peekable();
45+
46+
// Parse name
47+
let Some(mode) = items.next() else {
48+
cx.expected_at_least_one_argument(list.span);
49+
return None;
50+
};
51+
let Some(mode) = mode.meta_item() else {
52+
cx.expected_identifier(mode.span());
53+
return None;
54+
};
55+
let Ok(()) = mode.args().no_args() else {
56+
cx.expected_identifier(mode.span());
57+
return None;
58+
};
59+
let Some(mode) = mode.path().word() else {
60+
cx.expected_identifier(mode.span());
61+
return None;
62+
};
63+
let Ok(mode) = DiffMode::from_str(mode.as_str()) else {
64+
cx.expected_specific_argument(mode.span, DiffMode::all_modes());
65+
return None;
66+
};
67+
68+
// Parse width
69+
let width = if let Some(width) = items.peek()
70+
&& let MetaItemOrLitParser::Lit(width) = width
71+
&& let LitKind::Int(width, _) = width.kind
72+
&& let Ok(width) = width.0.try_into()
73+
{
74+
_ = items.next();
75+
width
76+
} else {
77+
1
78+
};
79+
80+
// Parse activities
81+
let mut activities = ThinVec::new();
82+
for activity in items {
83+
let MetaItemOrLitParser::MetaItemParser(activity) = activity else {
84+
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
85+
return None;
86+
};
87+
let Ok(()) = activity.args().no_args() else {
88+
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
89+
return None;
90+
};
91+
let Some(activity) = activity.path().word() else {
92+
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
93+
return None;
94+
};
95+
let Ok(activity) = DiffActivity::from_str(activity.as_str()) else {
96+
cx.expected_specific_argument(activity.span, DiffActivity::all_activities());
97+
return None;
98+
};
99+
100+
activities.push(activity);
101+
}
102+
let Some(ret_activity) = activities.pop() else {
103+
cx.expected_specific_argument(
104+
list.span.with_lo(list.span.hi()),
105+
DiffActivity::all_activities(),
106+
);
107+
return None;
108+
};
109+
110+
Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff {
111+
mode,
112+
width,
113+
input_activity: activities,
114+
ret_activity,
115+
}))))
116+
}
117+
}

compiler/rustc_attr_parsing/src/attributes/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use crate::target_checking::AllowedTargets;
3030
mod prelude;
3131

3232
pub(crate) mod allow_unstable;
33+
pub(crate) mod autodiff;
3334
pub(crate) mod body;
3435
pub(crate) mod cfg;
3536
pub(crate) mod cfg_select;

compiler/rustc_attr_parsing/src/context.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use rustc_span::{ErrorGuaranteed, Span, Symbol};
1919
use crate::AttributeParser;
2020
// Glob imports to avoid big, bitrotty import lists
2121
use crate::attributes::allow_unstable::*;
22+
use crate::attributes::autodiff::*;
2223
use crate::attributes::body::*;
2324
use crate::attributes::cfi_encoding::*;
2425
use crate::attributes::codegen_attrs::*;
@@ -202,6 +203,7 @@ attribute_parsers!(
202203
Single<ReexportTestHarnessMainParser>,
203204
Single<RustcAbiParser>,
204205
Single<RustcAllocatorZeroedVariantParser>,
206+
Single<RustcAutodiffParser>,
205207
Single<RustcBuiltinMacroParser>,
206208
Single<RustcDefPath>,
207209
Single<RustcDeprecatedSafe2024Parser>,

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@ use rustc_abi::{
66
Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size, WrappingRange,
77
};
88
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
9-
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
109
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
1110
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
1211
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
1312
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
1413
use rustc_codegen_ssa::traits::*;
1514
use rustc_data_structures::assert_matches;
1615
use rustc_hir::def_id::LOCAL_CRATE;
17-
use rustc_hir::{self as hir};
16+
use rustc_hir::{self as hir, find_attr};
1817
use rustc_middle::mir::BinOp;
1918
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
2019
use rustc_middle::ty::offload_meta::OffloadMetadata;
@@ -1367,7 +1366,9 @@ fn codegen_autodiff<'ll, 'tcx>(
13671366
let val_arr = get_args_from_tuple(bx, args[2], fn_diff);
13681367
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
13691368

1370-
let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else {
1369+
let Some(Some(mut diff_attrs)) =
1370+
find_attr!(tcx, fn_diff.def_id(), RustcAutodiff(attr) => attr.as_ref().map(Clone::clone))
1371+
else {
13711372
bug!("could not find autodiff attrs")
13721373
};
13731374

@@ -1389,7 +1390,7 @@ fn codegen_autodiff<'ll, 'tcx>(
13891390
&diff_symbol,
13901391
llret_ty,
13911392
&val_arr,
1392-
diff_attrs.clone(),
1393+
&diff_attrs,
13931394
result,
13941395
fnc_tree,
13951396
);

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 0 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
use std::str::FromStr;
2-
31
use rustc_abi::{Align, ExternAbi};
4-
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
5-
use rustc_ast::{LitKind, MetaItem, MetaItemInner};
62
use rustc_hir::attrs::{
73
AttributeKind, EiiImplResolution, InlineAttr, Linkage, RtsanSetting, UsedBy,
84
};
@@ -14,7 +10,6 @@ use rustc_middle::middle::codegen_fn_attrs::{
1410
};
1511
use rustc_middle::mir::mono::Visibility;
1612
use rustc_middle::query::Providers;
17-
use rustc_middle::span_bug;
1813
use rustc_middle::ty::{self as ty, TyCtxt};
1914
use rustc_session::lint;
2015
use rustc_session::parse::feature_err;
@@ -614,116 +609,6 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Align> {
614609
tcx.codegen_fn_attrs(tcx.trait_item_of(def_id)?).alignment
615610
}
616611

617-
/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
618-
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
619-
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
620-
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
621-
/// panic, unless we introduced a bug when parsing the autodiff macro.
622-
//FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed.
623-
pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
624-
#[allow(deprecated)]
625-
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
626-
627-
let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::<Vec<_>>();
628-
629-
// check for exactly one autodiff attribute on placeholder functions.
630-
// There should only be one, since we generate a new placeholder per ad macro.
631-
let attr = match &attrs[..] {
632-
[] => return None,
633-
[attr] => attr,
634-
_ => {
635-
span_bug!(attrs[1].span(), "cg_ssa: rustc_autodiff should only exist once per source");
636-
}
637-
};
638-
639-
let list = attr.meta_item_list().unwrap_or_default();
640-
641-
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
642-
if list.is_empty() {
643-
return Some(AutoDiffAttrs::source());
644-
}
645-
646-
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
647-
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
648-
};
649-
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
650-
p1.segments.first().unwrap().ident
651-
} else {
652-
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode");
653-
};
654-
655-
// parse mode
656-
let mode = match mode.as_str() {
657-
"Forward" => DiffMode::Forward,
658-
"Reverse" => DiffMode::Reverse,
659-
_ => {
660-
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
661-
}
662-
};
663-
664-
let width: u32 = match width_meta {
665-
MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
666-
let w = p1.segments.first().unwrap().ident;
667-
match w.as_str().parse() {
668-
Ok(val) => val,
669-
Err(_) => {
670-
span_bug!(w.span, "rustc_autodiff width should fit u32");
671-
}
672-
}
673-
}
674-
MetaItemInner::Lit(lit) => {
675-
if let LitKind::Int(val, _) = lit.kind {
676-
match val.get().try_into() {
677-
Ok(val) => val,
678-
Err(_) => {
679-
span_bug!(lit.span, "rustc_autodiff width should fit u32");
680-
}
681-
}
682-
} else {
683-
span_bug!(lit.span, "rustc_autodiff width should be an integer");
684-
}
685-
}
686-
};
687-
688-
// First read the ret symbol from the attribute
689-
let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity else {
690-
span_bug!(attr.span(), "rustc_autodiff attribute must contain the return activity");
691-
};
692-
let ret_symbol = p1.segments.first().unwrap().ident;
693-
694-
// Then parse it into an actual DiffActivity
695-
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
696-
span_bug!(ret_symbol.span, "invalid return activity");
697-
};
698-
699-
// Now parse all the intermediate (input) activities
700-
let mut arg_activities: Vec<DiffActivity> = vec![];
701-
for arg in input_activities {
702-
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p2, .. }) = arg {
703-
match p2.segments.first() {
704-
Some(x) => x.ident,
705-
None => {
706-
span_bug!(
707-
arg.span(),
708-
"rustc_autodiff attribute must contain the input activity"
709-
);
710-
}
711-
}
712-
} else {
713-
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
714-
};
715-
716-
match DiffActivity::from_str(arg_symbol.as_str()) {
717-
Ok(arg_activity) => arg_activities.push(arg_activity),
718-
Err(_) => {
719-
span_bug!(arg_symbol.span, "invalid input activity");
720-
}
721-
}
722-
}
723-
724-
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
725-
}
726-
727612
pub(crate) fn provide(providers: &mut Providers) {
728613
*providers = Providers {
729614
codegen_fn_attrs,

compiler/rustc_hir/src/attrs/data_structures.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,9 @@ pub enum AttributeKind {
12711271
/// Represents `#[rustc_as_ptr]` (used by the `dangling_pointers_from_temporaries` lint).
12721272
RustcAsPtr(Span),
12731273

1274+
/// Represents `#[rustc_autodiff]`.
1275+
RustcAutodiff(Option<Box<RustcAutodiff>>),
1276+
12741277
/// Represents `#[rustc_default_body_unstable]`.
12751278
RustcBodyStability {
12761279
stability: DefaultBodyStability,

compiler/rustc_hir/src/attrs/encode_cross_crate.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ impl AttributeKind {
100100
RustcAllowConstFnUnstable(..) => No,
101101
RustcAllowIncoherentImpl(..) => No,
102102
RustcAsPtr(..) => Yes,
103+
RustcAutodiff(..) => Yes,
103104
RustcBodyStability { .. } => No,
104105
RustcBuiltinMacro { .. } => Yes,
105106
RustcCaptureAnalysis => No,

compiler/rustc_hir/src/attrs/pretty_printing.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use rustc_abi::Align;
66
use rustc_ast::ast::{Path, join_path_idents};
77
use rustc_ast::attr::data_structures::CfgEntry;
88
use rustc_ast::attr::version::RustcVersion;
9+
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
910
use rustc_ast::token::{CommentKind, DocFragmentKind};
1011
use rustc_ast::{AttrId, AttrStyle, IntTy, UintTy};
1112
use rustc_ast_pretty::pp::Printer;

compiler/rustc_mir_transform/src/cross_crate_inline.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use rustc_middle::mir::*;
88
use rustc_middle::query::Providers;
99
use rustc_middle::ty::TyCtxt;
1010
use rustc_session::config::{InliningThreshold, OptLevel};
11-
use rustc_span::sym;
1211

1312
use crate::{inline, pass_manager as pm};
1413

@@ -37,11 +36,7 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
3736
}
3837

3938
// FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
40-
#[allow(deprecated)]
41-
if tcx.has_attr(def_id, sym::autodiff_forward)
42-
|| tcx.has_attr(def_id, sym::autodiff_reverse)
43-
|| tcx.has_attr(def_id, sym::rustc_autodiff)
44-
{
39+
if find_attr!(tcx, def_id, RustcAutodiff(..)) {
4540
return true;
4641
}
4742

0 commit comments

Comments
 (0)