Skip to content

Commit f66d5ce

Browse files
committed
fix: aggressive
1 parent 608186c commit f66d5ce

4 files changed

Lines changed: 211 additions & 40 deletions

File tree

libs/@local/hashql/mir/src/pass/transform/inline/find.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,8 @@ use crate::{
2323
///
2424
/// A callsite is eligible if:
2525
/// - It's a direct call (function is a constant `FnPtr`).
26-
/// - Its target SCC has not already been inlined into this caller.
27-
///
28-
/// The SCC check prevents cycles: once we've inlined a function (or any function
29-
/// in its SCC) into a filter, we won't inline it again.
26+
/// - It's not a self-call.
27+
/// - Its target is not a loop breaker.
3028
pub(crate) struct FindCallsiteVisitor<'ctx, 'state, 'env, 'heap, A: Allocator> {
3129
/// The filter function we're finding callsites in.
3230
pub caller: DefId,
@@ -53,10 +51,10 @@ impl<'heap, A: Allocator> Visitor<'heap> for FindCallsiteVisitor<'_, '_, '_, 'he
5351
return Ok(());
5452
};
5553

56-
let target_component = self.state.components.scc(ptr);
57-
58-
// Skip if we've already inlined this SCC into this caller.
59-
if self.state.inlined.contains(self.caller, target_component) {
54+
// Skip self-calls and calls to loop breakers. Breakers are the cycle
55+
// cut points: inlining them would reintroduce the recursion that
56+
// breaker selection removed.
57+
if ptr == self.caller || self.state.loop_breakers.contains(ptr) {
6058
return Ok(());
6159
}
6260

libs/@local/hashql/mir/src/pass/transform/inline/mod.rs

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
//! aggressive inlining to fully flatten the filter logic. The aggressive phase:
3434
//! 1. Iterates up to `aggressive_inline_cutoff` times per filter.
3535
//! 2. On each iteration, inlines all eligible callsites found in the filter.
36-
//! 3. Tracks which SCCs have been inlined to prevent cycles.
36+
//! 3. Calls to loop breakers and self-calls are skipped to prevent cycles.
3737
//! 4. Emits a diagnostic if the cutoff is reached.
3838
//!
3939
//! # Budget System
@@ -53,18 +53,12 @@ use alloc::collections::BinaryHeap;
5353
use core::{alloc::Allocator, cmp, mem};
5454

5555
use hashql_core::{
56-
graph::{
57-
DirectedGraph as _,
58-
algorithms::{
59-
Tarjan, TriColorDepthFirstSearch,
60-
tarjan::{Members, SccId, StronglyConnectedComponents},
61-
},
56+
graph::algorithms::{
57+
Tarjan, TriColorDepthFirstSearch,
58+
tarjan::{Members, SccId, StronglyConnectedComponents},
6259
},
6360
heap::{BumpAllocator, Heap},
64-
id::{
65-
Id as _, IdSlice,
66-
bit_vec::{DenseBitSet, SparseBitMatrix},
67-
},
61+
id::{Id as _, IdSlice, bit_vec::DenseBitSet},
6862
span::SpanId,
6963
};
7064

@@ -214,11 +208,6 @@ struct InlineState<'ctx, 'state, 'env, 'heap, A: Allocator> {
214208
/// Calls to a breaker within its SCC are skipped during inlining.
215209
/// Calls from a breaker to non-breakers are still inlined.
216210
loop_breakers: DenseBitSet<DefId>,
217-
/// Tracks which SCCs have been inlined into each function.
218-
///
219-
/// Used to prevent cycles during aggressive inlining: once an SCC
220-
/// has been inlined into a filter, it won't be inlined again.
221-
inlined: SparseBitMatrix<DefId, SccId, A>,
222211

223212
// cost estimation properties
224213
costs: CostEstimationResidual<'heap, A>,
@@ -231,23 +220,15 @@ struct InlineState<'ctx, 'state, 'env, 'heap, A: Allocator> {
231220
}
232221

233222
impl<'heap, A: Allocator> InlineState<'_, '_, '_, 'heap, A> {
234-
/// Collect all non-recursive callsites for aggressive inlining.
223+
/// Collect all callsites for aggressive inlining.
235224
///
236225
/// Used for filter functions which bypass normal heuristics.
237-
/// Records inlined SCCs to prevent cycles in subsequent iterations.
238-
fn collect_all_callsites(&mut self, body: DefId, mem: &mut InlineStateMemory<A>) {
239-
let component = self.components.scc(body);
240-
226+
/// Self-calls are excluded to prevent panics in `get_disjoint_mut`.
227+
fn collect_all_callsites(&self, body: DefId, mem: &mut InlineStateMemory<A>) {
241228
self.graph
242229
.apply_callsites(body)
243-
.filter(|callsite| self.components.scc(callsite.target) != component)
230+
.filter(|callsite| callsite.target != body)
244231
.collect_into(&mut mem.callsites);
245-
246-
self.inlined.insert(body, component);
247-
for callsite in &mem.callsites {
248-
self.inlined
249-
.insert(body, self.components.scc(callsite.target));
250-
}
251232
}
252233

253234
/// Collect callsites using heuristic scoring and budget.
@@ -570,7 +551,6 @@ impl<A: BumpAllocator> Inline<A> {
570551
config: self.config,
571552
filters,
572553
loop_breakers,
573-
inlined: SparseBitMatrix::new_in(components.node_count(), &self.alloc),
574554
interner,
575555
graph,
576556
costs,
@@ -640,9 +620,6 @@ impl<A: BumpAllocator> Inline<A> {
640620
mem.callsites
641621
.sort_unstable_by(|lhs, rhs| lhs.kind.cmp(&rhs.kind).reverse());
642622
for callsite in mem.callsites.drain(..) {
643-
let target_component = state.components.scc(callsite.target);
644-
state.inlined.insert(filter, target_component);
645-
646623
state.inline(bodies, callsite);
647624
}
648625

libs/@local/hashql/mir/src/pass/transform/inline/tests.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,3 +1656,69 @@ fn loop_breaker_all_always_directive() {
16561656
"the selected breaker must be one of the SCC members"
16571657
);
16581658
}
1659+
1660+
/// A filter function that calls into a mutually recursive SCC.
1661+
///
1662+
/// The aggressive phase should inline non-breaker B into the filter, but the
1663+
/// breaker A (visible after B's inlining) must not be expanded. Without the
1664+
/// unconditional breaker check in `FindCallsiteVisitor`, the aggressive phase
1665+
/// would re-expand A on each iteration until the cutoff.
1666+
#[test]
1667+
fn loop_breaker_aggressive_filter_with_recursive_scc() {
1668+
let heap = Heap::new();
1669+
let interner = Interner::new(&heap);
1670+
let env = Environment::new(&heap);
1671+
1672+
let a_id = DefId::new(0);
1673+
let b_id = DefId::new(1);
1674+
let filter_id = DefId::new(2);
1675+
1676+
// A: expensive, calls B (will be selected as breaker)
1677+
let a = body!(interner, env; fn@a_id/1 -> Int {
1678+
decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, result: Int;
1679+
bb0() {
1680+
cond = bin.== n 0;
1681+
if cond then bb1() else bb2();
1682+
},
1683+
bb1() { goto bb3(n); },
1684+
bb2() {
1685+
t1 = bin.+ n 1;
1686+
t2 = bin.+ t1 2;
1687+
t3 = apply (b_id), t2;
1688+
goto bb3(t3);
1689+
},
1690+
bb3(result) { return result; }
1691+
});
1692+
1693+
// B: cheap, calls A
1694+
let b = body!(interner, env; fn@b_id/1 -> Int {
1695+
decl x: Int, result: Int;
1696+
bb0() {
1697+
result = apply (a_id), x;
1698+
return result;
1699+
}
1700+
});
1701+
1702+
// Filter: calls B
1703+
let filter = body!(interner, env; [graph::read::filter]@filter_id/1 -> Int {
1704+
decl x: Int, result: Int;
1705+
bb0() {
1706+
result = apply (b_id), x;
1707+
return result;
1708+
}
1709+
});
1710+
1711+
let mut bodies = [a, b, filter];
1712+
1713+
assert_inline_pass(
1714+
"loop_breaker_aggressive_filter",
1715+
&mut bodies,
1716+
&mut MirContext {
1717+
heap: &heap,
1718+
env: &env,
1719+
interner: &interner,
1720+
diagnostics: DiagnosticIssues::new(),
1721+
},
1722+
InlineConfig::default(),
1723+
);
1724+
}

libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_aggressive_filter.snap

Lines changed: 130 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)