Skip to content

Commit 98e8a45

Browse files
authored
Delete unused refs from the tuple argument entirely (#221)
* First pass at deleting refs from the tuple entirely * Fix bug
1 parent 80969d0 commit 98e8a45

3 files changed

Lines changed: 58 additions & 32 deletions

File tree

src/bbcode.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ const IdToIdDict = Dict{ID,ID}
371371
function replace_ids(d::IdToIdDict, inst::NewInstruction)
372372
return NewInstruction(inst; stmt=replace_ids(d, inst.stmt))
373373
end
374+
replace_ids(d::IdToIdDict, x::ID) = get(d, x, x)
374375
function replace_ids(d::IdToIdDict, x::ReturnNode)
375376
return isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x
376377
end

src/refelim.jl

Lines changed: 56 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,6 @@ actually needed across boundaries, and eliminates all calls to `set_ref_at!` tha
3737
to be retained.
3838
3939
Returns a tuple of the modified `BBCode` and the modified `refs` tuple.
40-
41-
!!! note
42-
Right now, `eliminate_refs` does not remove dead refs from the `refs` tuple itself (so the
43-
TapedTask will be constructed with the same `refs` tuple as before). We simply leave those
44-
refs as unused (i.e., they will be initialised with nothing, and never read from or
45-
written to.) In principle, we could also slim down the `refs` tuple itself by removing
46-
the dead refs from it. This is left as a future optimisation (and the signature of this
47-
function is designed to allow for this in the future).
4840
"""
4941
function eliminate_refs(ir::BBCode, refs::Tuple)
5042
# The `refs` tuple contains a series of `Ref`s which are used to maintain function state
@@ -189,21 +181,15 @@ function eliminate_refs(ir::BBCode, refs::Tuple)
189181
# Only the refs that are live at the end of some basic block anywhere in the function
190182
# need to be kept. Note that the last ref in `refs` is always mandatory: it's the one
191183
# that stores the return block (i.e., how far through the function it's progressed).
192-
necessary_ref_ids = sort!(collect(union(values(live_out)...)))
193-
unnecessary_ref_ids = setdiff(1:(length(refs) - 1), necessary_ref_ids)
184+
necessary_ref_ids = sort!(vcat(length(refs), collect(union(values(live_out)...))))
185+
unnecessary_ref_ids = setdiff(1:length(refs), necessary_ref_ids)
194186

195-
# TODO(penelopeysm): We could reduce the size of the ref tuple itself, by dropping refs
196-
# that are never used. I think this is not super important right now: it doesn't really
197-
# hurt to have extra refs lying around in the tuple, because they're just initialised to
198-
# essentially null pointers and never read/written to. But in principle we could get rid
199-
# of them too.
200-
#
201-
# new_refs = tuple(
202-
# [ref for (i, ref) in enumerate(refs) if !(i in unnecessary_ref_ids)]...
203-
# )
204-
# refid_to_new_refid_map = Dict{Int,Int}(
205-
# necessary_ref_ids[i] => i for i in eachindex(necessary_ref_ids)
206-
# )
187+
new_refs = map(i -> refs[i], tuple(necessary_ref_ids...))
188+
# Suppose that we want to keep refs 1, 4, and 5. Then this map would be Dict(1 => 1, 4
189+
# => 2, 5 => 3).
190+
refid_to_new_refid_map = Dict{Int,Int}(
191+
refid => i for (i, refid) in enumerate(necessary_ref_ids)
192+
)
207193

208194
# We now need to go through the IR and remove calls that get/set the unnecessary refs.
209195
new_bblocks = map(ir.blocks) do block
@@ -263,8 +249,14 @@ function eliminate_refs(ir::BBCode, refs::Tuple)
263249
else
264250
error("Unexpected value argument to set_ref_at!: $value_arg")
265251
end
266-
ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst)
267-
push!(new_insts, (id, ninst))
252+
ninst = Expr(
253+
:call,
254+
Libtask.set_ref_at!,
255+
replace_ids(old_ssaid_to_new_ssaid_map, inst.stmt.args[2]),
256+
refid_to_new_refid_map[refid],
257+
replace_ids(old_ssaid_to_new_ssaid_map, inst.stmt.args[4]),
258+
)
259+
push!(new_insts, (id, new_inst(ninst)))
268260
end
269261
elseif call_func == Libtask.get_ref_at
270262
refid = inst.stmt.args[3]
@@ -279,30 +271,63 @@ function eliminate_refs(ir::BBCode, refs::Tuple)
279271
old_ssaid_to_new_ssaid_map[id] = refid_to_ssaid_map[refid]
280272
else
281273
# It's a get that we legitimately still need.
282-
ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst)
283-
push!(new_insts, (id, ninst))
274+
ninst = Expr(
275+
:call,
276+
Libtask.get_ref_at,
277+
replace_ids(old_ssaid_to_new_ssaid_map, inst.stmt.args[2]),
278+
refid_to_new_refid_map[refid],
279+
)
280+
push!(new_insts, (id, new_inst(ninst)))
284281
end
285282
else
286283
# Some other call instruction.
287284
ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst)
288285
push!(new_insts, (id, ninst))
289286
end
287+
elseif inst.stmt isa IDPhiNode
288+
# Replace any SSA IDs in the phi node.
289+
ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst)
290+
# then replace any TupleRefs with the new ref id
291+
new_values = Vector{Any}(undef, length(ninst.stmt.values))
292+
for n in eachindex(ninst.stmt.values)
293+
if isassigned(ninst.stmt.values, n)
294+
val = ninst.stmt.values[n]
295+
new_values[n] = if val isa Libtask.TupleRef
296+
if !haskey(refid_to_new_refid_map, val.n)
297+
# This should never happen, because if `val.n` was in the
298+
# phi node, it always counts as an upwards-exposed use of
299+
# that ref, and should therefore always be included in
300+
# `necessary_ref_ids`.
301+
error("found TupleRef with unused ref id $(val.n)")
302+
end
303+
TupleRef(refid_to_new_refid_map[val.n])
304+
else
305+
val
306+
end
307+
end
308+
end
309+
push!(new_insts, (id, new_inst(IDPhiNode(ninst.stmt.edges, new_values))))
290310
else
291-
# Some other (non-call) instruction.
311+
# Some other (non-call, non-PhiNode) instruction.
292312
ninst = replace_ids(old_ssaid_to_new_ssaid_map, inst)
293313
push!(new_insts, (id, ninst))
294314
end
295315
end
296316
return BBlock(block.id, new_insts)
297317
end
298318

319+
# The tuple of refs is passed in as the first argument to the IR, so we need to update
320+
# the types.
321+
new_argtypes = vcat(typeof(new_refs), copy(ir.argtypes[2:end]))
322+
299323
new_ir = @static if VERSION >= v"1.12-"
300-
BBCode(new_bblocks, ir.argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds)
324+
BBCode(
325+
new_bblocks, new_argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds
326+
)
301327
else
302-
BBCode(new_bblocks, ir.argtypes, ir.sptypes, ir.linetable, ir.meta)
328+
BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta)
303329
end
304-
# return ir, refs
305-
return new_ir, refs
330+
return new_ir, new_refs
306331
end
307332

308333
# Return a vector of block IDs in reverse postorder on the reverse CFG (i.e., the CFG where

src/test_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ function test_cases()
226226
"default kwarg tester", nothing, (default_kwarg_tester, 4.0), (;), [], allocs
227227
),
228228
Testcase(
229-
"final statment produce",
229+
"final statement produce",
230230
nothing,
231231
(final_statement_produce,),
232232
nothing,

0 commit comments

Comments
 (0)