src/ir/module.cc: 3 changes (+3 −0)
@@ -152,6 +152,9 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
this->functions.Set(var, func);

// Keep the GlobalVar's cached struct info in sync with the function it now binds.
const auto* var_node = var.get();
var_node->struct_info_ = func->struct_info_;

auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
ICHECK_EQ((*it).second, var);
src/relax/transform/dead_code_elimination.cc: 172 changes (+171 −1)
@@ -28,10 +28,14 @@
* (those where the bound var is unused and no impure operation is used).
* 2. Unused Relax functions in the module.
* We detect the call chain from the entry function, and remove all unused functions.
* 3. Unused function parameters.
* We detect unused parameters after removing unused VarBindings, then remove them from
* function signatures and from every call site.
*
* Any binding blocks that are left empty will be removed by the normalizer.
*/

#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/analysis.h>
#include <tvm/relax/analysis.h>
@@ -91,6 +95,153 @@ IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set<GlobalVar>
return mod;
}

// Two-stage dead parameter elimination:
// 1. Collect each function's used parameter indices, propagating through call chains.
// 2. Rewrite every call site to match the new signatures and argument lists.
std::unordered_map<GlobalVar, std::vector<int>> CollectUsedParamIndices(const IRModule& mod) {
std::unordered_map<GlobalVar, std::vector<int>> result;

for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt_func = base_func.as<Function>()) {
auto func = opt_func.value();
std::vector<bool> used(func->params.size(), false);

PostOrderVisit(func->body, [&](const ObjectRef& obj) {
if (auto v = obj.as<VarNode>()) {
Var var = ffi::GetRef<Var>(v);
for (size_t i = 0; i < func->params.size(); ++i) {
if (var.same_as(func->params[i])) {
used[i] = true;
}
}
}
});
Comment on lines +109 to +118
Contributor review comment (severity: medium):
This approach to finding used parameters involves a nested loop within PostOrderVisit, which can be inefficient for functions with large bodies. For each variable encountered, it iterates through all function parameters. This can be optimized by building a map from parameter Var to its index before the visit, and then performing a direct lookup for each variable inside the visitor. This would change the complexity from O(num_vars * num_params) to O(num_vars + num_params).

      std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual> param_to_idx;
      for (size_t i = 0; i < func->params.size(); ++i) {
        param_to_idx[func->params[i]] = i;
      }

      PostOrderVisit(func->body, [&](const ObjectRef& obj) {
        if (auto v = obj.as<VarNode>()) {
          Var var = ffi::GetRef<Var>(v);
          auto it = param_to_idx.find(var);
          if (it != param_to_idx.end()) {
            used[it->second] = true;
          }
        }
      });


std::vector<int> indices;
for (size_t i = 0; i < used.size(); ++i) {
if (used[i]) indices.push_back(i);
}

result[gvar] = std::move(indices);
}
}

return result;
}

struct CallSiteUpdater : public ExprMutator {
const std::unordered_map<GlobalVar, std::vector<int>>& used_param_indices;

explicit CallSiteUpdater(const std::unordered_map<GlobalVar, std::vector<int>>& used)
: ExprMutator(std::nullopt), used_param_indices(used) {}

using ExprMutator::VisitExpr_;

Expr VisitExpr_(const CallNode* call) final {
if (auto gvar = call->op.as<GlobalVar>()) {
auto it = used_param_indices.find(gvar.value());
if (it != used_param_indices.end()) {
const auto& used = it->second;

if (used.size() == call->args.size()) {
return ExprMutator::VisitExpr_(call);
}

ffi::Array<Expr> new_args;
for (int idx : used) {
new_args.push_back(call->args[idx]);
}

auto new_call = Call(call->op, new_args, call->attrs);
if (call->struct_info_.defined()) {
new_call->struct_info_ = call->struct_info_;
}
Comment on lines +156 to +158
Contributor review comment (severity: high):
Manually copying the struct_info_ from the old call to the new call is incorrect. When arguments are removed from a call, the struct_info of the call (which describes the output) may change, especially if the return type depends on the removed arguments. The ExprMutator framework will handle re-inferring the struct_info for the new call node during normalization. Removing this manual assignment will allow the correct struct_info to be inferred.
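
A minimal sketch of the fix this comment suggests (my reading of the comment, not code from the PR): drop the manual copy and let normalization re-derive the struct info.

      // Sketch: rebuild the call without copying struct_info_ from the old node;
      // per the comment, normalization re-infers the output struct info
      // from the updated argument list.
      auto new_call = Call(call->op, new_args, call->attrs);
      return new_call;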

return new_call;
}
}
return ExprMutator::VisitExpr_(call);
}
};

IRModule RemoveUnusedParameters(IRModule mod) {
auto write_ptr = mod.CopyOnWrite();
bool changed = true;

do {
changed = false;

auto used_param_indices = CollectUsedParamIndices(mod);

for (const auto& [gvar, used] : used_param_indices) {
if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
auto func = opt_func.value();
if (used.size() < func->params.size()) {
changed = true;
break;
}
}
}

if (!changed) break;

std::vector<GlobalVar> worklist;
std::unordered_set<GlobalVar> visited;
std::function<void(GlobalVar)> dfs = [&](GlobalVar gvar) {
if (visited.count(gvar)) return;
visited.insert(gvar);

if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
auto func = opt_func.value();
PostOrderVisit(func->body, [&](const ObjectRef& obj) {
if (auto call = obj.as<CallNode>()) {
if (auto callee_gvar = call->op.as<GlobalVar>()) {
dfs(callee_gvar.value());
}
}
});
}
worklist.push_back(gvar);
};

for (const auto& [gvar, _] : mod->functions) {
dfs(gvar);
}
Comment on lines +187 to +208
Contributor review comment (severity: medium):
The topological sort of the call graph is performed inside the do-while loop. However, the call graph structure does not change during dead parameter elimination. You can move the topological sort outside the loop to avoid re-computing it on every iteration, which would improve performance.
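
A sketch of that restructuring (hypothetical helper name TopoSortCallees; assumes, as the comment says, that parameter removal never adds or removes call edges):

// Hypothetical helper: compute the callees-first ordering once, before the
// fixed-point loop, instead of rebuilding it on every iteration.
std::vector<GlobalVar> TopoSortCallees(const IRModule& mod) {
  std::vector<GlobalVar> order;
  std::unordered_set<GlobalVar> visited;
  std::function<void(const GlobalVar&)> dfs = [&](const GlobalVar& gvar) {
    if (!visited.insert(gvar).second) return;
    if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
      PostOrderVisit(opt_func.value()->body, [&](const ObjectRef& obj) {
        if (auto call = obj.as<CallNode>()) {
          if (auto callee = call->op.as<GlobalVar>()) {
            dfs(callee.value());
          }
        }
      });
    }
    order.push_back(gvar);  // callees are pushed before their callers
  };
  for (const auto& [gvar, _] : mod->functions) {
    dfs(gvar);
  }
  return order;
}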


for (const auto& gvar : worklist) {
if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
auto func = opt_func.value();
auto it = used_param_indices.find(gvar);
if (it != used_param_indices.end() && it->second.size() < func->params.size()) {
ffi::Array<Var> new_params;
for (int old_idx : used_param_indices[gvar]) {
new_params.push_back(func->params[old_idx]);
}

Function new_func = Function(new_params, func->body, func->ret_struct_info, func->is_pure,
func->attrs, func->span);
write_ptr->Update(gvar, new_func);
}
}
}

CallSiteUpdater updater(used_param_indices);
for (const auto& gvar : worklist) {
if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
auto func = opt_func.value();
Expr updated_body = updater(func->body);

if (!updated_body.same_as(func->body)) {
Function new_func(func->params, updated_body, func->ret_struct_info, func->is_pure,
func->attrs, func->span);
write_ptr->Update(gvar, new_func);
}
}
}
} while (changed);

return mod;
}

IRModule DeadCodeElimination(const IRModule& arg_mod,
ffi::Array<ffi::String> entry_function_names) {
IRModule mod = arg_mod;
@@ -127,9 +278,28 @@ IRModule DeadCodeElimination(const IRModule& arg_mod,
}
}

// S3: remove unused functions again as some callers may be removed in S2.
// S3: remove unused parameters in each function
mod = RemoveUnusedParameters(mod);

// S4: remove unused functions again as some callers may be removed in S2 and S3.
mod = RemoveUnusedFunctions(mod, entry_functions);

// S5: remove unused variables again as some arguments may be removed in S3
{
IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<Function>()) {
auto new_func = Downcast<Function>(RemoveAllUnused(opt.value()));
if (!new_func.same_as(base_func)) {
updates->Add(gvar, new_func);
}
}
}
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
}
Comment on lines +287 to +301
Contributor review comment (severity: medium):
This block of code for removing unused variables (S5) is identical to the block for S2 (lines 266-279). To improve maintainability and avoid code duplication, you could extract this logic into a helper function, for example RemoveUnusedVariables(IRModule mod), and call it in both places.
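
A sketch of the suggested helper (hypothetical name RemoveUnusedVariables, lifted directly from the duplicated block):

// Hypothetical helper consolidating the identical S2 and S5 blocks.
IRModule RemoveUnusedVariables(IRModule mod) {
  IRModule updates;
  for (const auto& [gvar, base_func] : mod->functions) {
    if (auto opt = base_func.as<Function>()) {
      auto new_func = Downcast<Function>(RemoveAllUnused(opt.value()));
      if (!new_func.same_as(base_func)) {
        updates->Add(gvar, new_func);
      }
    }
  }
  if (updates->functions.size()) {
    mod.CopyOnWrite()->Update(updates);
  }
  return mod;
}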


return mod;
}

tests/python/relax/test_transform_dead_code_elimination.py: 116 changes (+114 −2)
@@ -65,7 +65,7 @@ class Expected:
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
bias: R.Tensor((26, 26), dtype="float32"),
# bias: R.Tensor((26, 26), dtype="float32"), REMOVED after dead parameter elimination pass
):
with R.dataflow():
gv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
@@ -127,7 +127,7 @@ class Expected:
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
bias: R.Tensor((26, 26), dtype="float32"),
# bias: R.Tensor((26, 26), dtype="float32"), REMOVED after dead parameter elimination pass
):
with R.dataflow():
gv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
@@ -799,5 +799,117 @@ def while_loop(
verify(Before, Expected)


def test_mutual_recursion_unused_params():
"""Test that unused parameters are removed even in mutual recursion"""

@tvm.script.ir_module
class Input:
@R.function
def func_a(
x: R.Tensor([32, 32], "float32"),
y: R.Tensor([32, 32], "float32"), # Unused in both
z: R.Tensor([32, 32], "float32"),
):
cls = Input
out = R.add(x, z)
result = cls.func_b(out, y, z) # y passed but func_b doesn't use it
return result

@R.function
def func_b(
a: R.Tensor([32, 32], "float32"),
b: R.Tensor([32, 32], "float32"), # Unused
c: R.Tensor([32, 32], "float32"),
):
cls = Input
out = R.add(a, c)
result = cls.func_a(out, b, c)
return result

@R.function
def main():
x = R.zeros([32, 32], "float32")
y = R.ones([32, 32], "float32")
z = R.full([32, 32], R.const(2.0, dtype="float32"), "float32")
return Input.func_a(x, y, z)

@tvm.script.ir_module
class Expected:
@R.function
def func_a(x: R.Tensor([32, 32], "float32"), z: R.Tensor([32, 32], "float32")):
cls = Expected
out = R.add(x, z)
result = cls.func_b(out, z)
return result

@R.function
def func_b(a: R.Tensor([32, 32], "float32"), c: R.Tensor([32, 32], "float32")):
cls = Expected
out = R.add(a, c)
result = cls.func_a(out, c)
return result

@R.function
def main():
x = R.zeros([32, 32], "float32")
z = R.full([32, 32], R.const(2.0, dtype="float32"), "float32")
return Expected.func_a(x, z)

verify(Input, Expected)


def test_deep_recursion_chain():
"""Test parameter removal through a chain of recursive calls"""

@tvm.script.ir_module
class Input:
@R.function
def depth_1(x: R.Tensor([32], "float32"), dead1: R.Tensor([32], "float32")): # Unused
out = R.add(x, x)
result = Input.depth_2(out, dead1)
return result

@R.function
def depth_2(a: R.Tensor([32], "float32"), dead2: R.Tensor([32], "float32")): # Unused
out = R.multiply(a, a)
result = Input.depth_3(out, dead2)
return result

@R.function
def depth_3(b: R.Tensor([32], "float32"), dead3: R.Tensor([32], "float32")): # Unused
return R.subtract(b, b)

@R.function
def main():
x = R.zeros([32], "float32")
dead = R.ones([32], "float32")
return Input.depth_1(x, dead)

@tvm.script.ir_module
class Expected:
@R.function
def depth_1(x: R.Tensor([32], "float32")):
out = R.add(x, x)
result = Expected.depth_2(out)
return result

@R.function
def depth_2(a: R.Tensor([32], "float32")):
out = R.multiply(a, a)
result = Expected.depth_3(out)
return result

@R.function
def depth_3(b: R.Tensor([32], "float32")):
return R.subtract(b, b)

@R.function
def main():
x = R.zeros([32], "float32")
return Expected.depth_1(x)

verify(Input, Expected)


if __name__ == "__main__":
tvm.testing.main()