diff --git a/src/ir/module.cc b/src/ir/module.cc
index e7f251103814..9d69f13a69cc 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -152,6 +152,9 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
 
 void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
   this->functions.Set(var, func);
+  const auto* var_node = var.get();
+  var_node->struct_info_ = func->struct_info_;
+
   auto it = global_var_map_.find(var->name_hint);
   if (it != global_var_map_.end()) {
     ICHECK_EQ((*it).second, var);
diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc
index fbb077ddf941..a6bdbb5521ca 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -28,10 +28,14 @@
  *    (those where the bound var is unused and no impure operation is used).
  * 2. Unused Relax functions in the module.
  *    We detect the call chain from the entry function, and remove all unused functions.
+ * 3. Unused function parameters.
+ *    We detect unused parameters after removing unused VarBindings and remove them from
+ *    function signatures and call sites.
  *
  * Any binding blocks that are left empty will be removed by the normalizer.
  */
 
+#include <functional>
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
@@ -91,6 +95,153 @@ IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set
   return mod;
 }
 
+// Two-stage dead parameter elimination:
+//   1. collect, per function, the indices of parameters that are still used, and
+//   2. update all call sites (and the function signatures) to drop the unused ones.
+// The two stages are iterated to a fixed point, so an argument that only fed a removed
+// parameter is itself removed in a later round.
+std::unordered_map<GlobalVar, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual>
+CollectUsedParamIndices(const IRModule& mod) {
+  std::unordered_map<GlobalVar, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual> result;
+
+  for (const auto& [gvar, base_func] : mod->functions) {
+    if (auto opt_func = base_func.as<Function>()) {
+      auto func = opt_func.value();
+      std::vector<bool> used(func->params.size(), false);
+
+      // A parameter counts as used if it occurs anywhere in the body.
+      PostOrderVisit(func->body, [&](const ObjectRef& obj) {
+        if (auto v = obj.as<VarNode>()) {
+          Var var = ffi::GetRef<Var>(v);
+          for (size_t i = 0; i < func->params.size(); ++i) {
+            if (var.same_as(func->params[i])) {
+              used[i] = true;
+            }
+          }
+        }
+      });
+
+      std::vector<size_t> indices;
+      for (size_t i = 0; i < used.size(); ++i) {
+        if (used[i]) indices.push_back(i);
+      }
+
+      result[gvar] = std::move(indices);
+    }
+  }
+
+  return result;
+}
+
+struct CallSiteUpdater : public ExprMutator {
+  const std::unordered_map<GlobalVar, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual>&
+      used_param_indices;
+
+  explicit CallSiteUpdater(
+      const std::unordered_map<GlobalVar, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual>& used)
+      : ExprMutator(std::nullopt), used_param_indices(used) {}
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* call) final {
+    if (auto gvar = call->op.as<GlobalVar>()) {
+      auto it = used_param_indices.find(gvar.value());
+      if (it != used_param_indices.end()) {
+        const auto& used = it->second;
+
+        // Nothing was removed from this callee's signature.
+        if (used.size() == call->args.size()) {
+          return ExprMutator::VisitExpr_(call);
+        }
+
+        ffi::Array<Expr> new_args;
+        for (size_t idx : used) {
+          new_args.push_back(call->args[idx]);
+        }
+
+        auto new_call = Call(call->op, new_args, call->attrs);
+        if (call->struct_info_.defined()) {
+          new_call->struct_info_ = call->struct_info_;
+        }
+        return new_call;
+      }
+    }
+    return ExprMutator::VisitExpr_(call);
+  }
+};
+
+IRModule RemoveUnusedParameters(IRModule mod) {
+  auto write_ptr = mod.CopyOnWrite();
+  bool changed = true;
+
+  do {
+    changed = false;
+
+    auto used_param_indices = CollectUsedParamIndices(mod);
+
+    for (const auto& [gvar, used] : used_param_indices) {
+      if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
+        auto func = opt_func.value();
+        if (used.size() < func->params.size()) {
+          changed = true;
+          break;
+        }
+      }
+    }
+
+    if (!changed) break;
+
+    // Visit callees before callers so signatures are rewritten bottom-up.
+    std::vector<GlobalVar> worklist;
+    std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> visited;
+    std::function<void(GlobalVar)> dfs = [&](GlobalVar gvar) {
+      if (visited.count(gvar)) return;
+      visited.insert(gvar);
+
+      if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
+        auto func = opt_func.value();
+        PostOrderVisit(func->body, [&](const ObjectRef& obj) {
+          if (auto call = obj.as<CallNode>()) {
+            if (auto callee_gvar = call->op.as<GlobalVar>()) {
+              dfs(callee_gvar.value());
+            }
+          }
+        });
+      }
+      worklist.push_back(gvar);
+    };
+
+    for (const auto& [gvar, _] : mod->functions) {
+      dfs(gvar);
+    }
+
+    // Stage 1: shrink the signatures of functions with unused parameters.
+    for (const auto& gvar : worklist) {
+      if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
+        auto func = opt_func.value();
+        auto it = used_param_indices.find(gvar);
+        if (it != used_param_indices.end() && it->second.size() < func->params.size()) {
+          ffi::Array<Var> new_params;
+          for (size_t old_idx : it->second) {
+            new_params.push_back(func->params[old_idx]);
+          }
+
+          Function new_func = Function(new_params, func->body, func->ret_struct_info,
+                                       func->is_pure, func->attrs, func->span);
+          write_ptr->Update(gvar, new_func);
+        }
+      }
+    }
+
+    // Stage 2: drop the corresponding arguments at every call site.
+    CallSiteUpdater updater(used_param_indices);
+    for (const auto& gvar : worklist) {
+      if (auto opt_func = mod->Lookup(gvar).as<Function>()) {
+        auto func = opt_func.value();
+        Expr updated_body = updater(func->body);
+
+        if (!updated_body.same_as(func->body)) {
+          Function new_func(func->params, updated_body, func->ret_struct_info, func->is_pure,
+                            func->attrs, func->span);
+          write_ptr->Update(gvar, new_func);
+        }
+      }
+    }
+  } while (changed);
+
+  return mod;
+}
+
 IRModule DeadCodeElimination(const IRModule& arg_mod,
                              ffi::Array<ffi::String> entry_function_names) {
   IRModule mod = arg_mod;
@@ -127,9 +278,28 @@ IRModule DeadCodeElimination(const IRModule& arg_mod,
     }
   }
 
-  // S3: remove unused functions again as some callers may be removed in S2.
+  // S3: remove unused parameters in each function.
+  mod = RemoveUnusedParameters(mod);
+
+  // S4: remove unused functions again as some callers may be removed in S2 and S3.
   mod = RemoveUnusedFunctions(mod, entry_functions);
 
+  // S5: remove unused variables again as some arguments may be removed in S3.
+  {
+    IRModule updates;
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto opt = base_func.as<Function>()) {
+        auto new_func = Downcast<Function>(RemoveAllUnused(opt.value()));
+        if (!new_func.same_as(base_func)) {
+          updates->Add(gvar, new_func);
+        }
+      }
+    }
+    if (updates->functions.size()) {
+      mod.CopyOnWrite()->Update(updates);
+    }
+  }
+
   return mod;
 }
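The `RemoveUnusedParameters` pass above iterates two stages to a fixed point: collect, per function, the indices of parameters that still occur in the body, then shrink the signatures and rewrite every call site to match. The sketch below is not TVM code; the dict-based toy IR and the helper name are invented purely to illustrate the same scheme in plain Python.

```python
def remove_unused_params(funcs):
    """Toy model: funcs maps name -> {"params": [...], "reads": {...}, "calls": [(callee, [args])]}.

    A parameter counts as used if it occurs anywhere in the body, i.e. in the
    read set or as a call argument (mirroring CollectUsedParamIndices above).
    """
    changed = True
    while changed:
        changed = False
        # Stage 1: per function, record the indices of parameters that still occur.
        used = {}
        for name, f in funcs.items():
            occurs = set(f["reads"])
            for _, args in f["calls"]:
                occurs.update(a for a in args if a in f["params"])
            used[name] = [i for i, p in enumerate(f["params"]) if p in occurs]
        # Stage 2: shrink signatures and rewrite call sites (the CallSiteUpdater role).
        for name, f in funcs.items():
            keep = used[name]
            if len(keep) < len(f["params"]):
                changed = True
                f["params"] = [f["params"][i] for i in keep]
            f["calls"] = [
                (callee, [args[i] for i in used[callee]] if callee in used else args)
                for callee, args in f["calls"]
            ]
    return funcs


# A chain where only the last function ignores its second parameter: each
# iteration peels one more dead argument off the chain until a fixed point.
toy = {
    "depth_1": {"params": ["x", "d1"], "reads": {"x"}, "calls": [("depth_2", ["x", "d1"])]},
    "depth_2": {"params": ["a", "d2"], "reads": {"a"}, "calls": [("depth_3", ["a", "d2"])]},
    "depth_3": {"params": ["b", "d3"], "reads": {"b"}, "calls": []},
}
print(remove_unused_params(toy))  # every "d*" parameter and argument is gone
```
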
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py
index afaa1f107d6c..78bfefeefbf9 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -65,7 +65,7 @@ class Expected:
         def main(
             x: R.Tensor((2, 3, 28, 28), dtype="float32"),
             w: R.Tensor((4, 3, 3, 3), dtype="float32"),
-            bias: R.Tensor((26, 26), dtype="float32"),
+            # bias: R.Tensor((26, 26), dtype="float32"), REMOVED after dead parameter elimination pass
         ):
             with R.dataflow():
                 gv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
@@ -127,7 +127,7 @@ class Expected:
         def main(
             x: R.Tensor((2, 3, 28, 28), dtype="float32"),
             w: R.Tensor((4, 3, 3, 3), dtype="float32"),
-            bias: R.Tensor((26, 26), dtype="float32"),
+            # bias: R.Tensor((26, 26), dtype="float32"), REMOVED after dead parameter elimination pass
         ):
             with R.dataflow():
                 gv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
@@ -799,5 +799,117 @@ def while_loop(
     verify(Before, Expected)
 
 
+def test_mutual_recursion_unused_params():
+    """Test that unused parameters are removed even in mutual recursion"""
+
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def func_a(
+            x: R.Tensor([32, 32], "float32"),
+            y: R.Tensor([32, 32], "float32"),  # Unused in both
+            z: R.Tensor([32, 32], "float32"),
+        ):
+            cls = Input
+            out = R.add(x, z)
+            result = cls.func_b(out, y, z)  # y passed but func_b doesn't use it
+            return result
+
+        @R.function
+        def func_b(
+            a: R.Tensor([32, 32], "float32"),
+            b: R.Tensor([32, 32], "float32"),  # Unused
+            c: R.Tensor([32, 32], "float32"),
+        ):
+            cls = Input
+            out = R.add(a, c)
+            result = cls.func_a(out, b, c)
+            return result
+
+        @R.function
+        def main():
+            x = R.zeros([32, 32], "float32")
+            y = R.ones([32, 32], "float32")
+            z = R.full([32, 32], R.const(2.0, dtype="float32"), "float32")
+            return Input.func_a(x, y, z)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def func_a(x: R.Tensor([32, 32], "float32"), z: R.Tensor([32, 32], "float32")):
+            cls = Expected
+            out = R.add(x, z)
+            result = cls.func_b(out, z)
+            return result
+
+        @R.function
+        def func_b(a: R.Tensor([32, 32], "float32"), c: R.Tensor([32, 32], "float32")):
+            cls = Expected
+            out = R.add(a, c)
+            result = cls.func_a(out, c)
+            return result
+
+        @R.function
+        def main():
+            x = R.zeros([32, 32], "float32")
+            z = R.full([32, 32], R.const(2.0, dtype="float32"), "float32")
+            return Expected.func_a(x, z)
+
+    verify(Input, Expected)
+
+
+def test_deep_recursion_chain():
+    """Test parameter removal through a chain of recursive calls"""
+
+    @tvm.script.ir_module
+    class Input:
+        @R.function
+        def depth_1(x: R.Tensor([32], "float32"), dead1: R.Tensor([32], "float32")):  # Unused
+            out = R.add(x, x)
+            result = Input.depth_2(out, dead1)
+            return result
+
+        @R.function
+        def depth_2(a: R.Tensor([32], "float32"), dead2: R.Tensor([32], "float32")):  # Unused
+            out = R.multiply(a, a)
+            result = Input.depth_3(out, dead2)
+            return result
+
+        @R.function
+        def depth_3(b: R.Tensor([32], "float32"), dead3: R.Tensor([32], "float32")):  # Unused
+            return R.subtract(b, b)
+
+        @R.function
+        def main():
+            x = R.zeros([32], "float32")
+            dead = R.ones([32], "float32")
+            return Input.depth_1(x, dead)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def depth_1(x: R.Tensor([32], "float32")):
+            out = R.add(x, x)
+            result = Expected.depth_2(out)
+            return result
+
+        @R.function
+        def depth_2(a: R.Tensor([32], "float32")):
+            out = R.multiply(a, a)
+            result = Expected.depth_3(out)
+            return result
+
+        @R.function
+        def depth_3(b: R.Tensor([32], "float32")):
+            return R.subtract(b, b)
+
+        @R.function
+        def main():
+            x = R.zeros([32], "float32")
+            return Expected.depth_1(x)
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
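For reference, a minimal usage sketch (not part of the patch; the module below is a hypothetical example): applying the pass these tests exercise to a function with a parameter that is never read. Based on the updated expectations above, the pass should now drop the parameter from the signature in addition to removing dead bindings and unreachable functions.

```python
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((4,), "float32"),
        unused: R.Tensor((4,), "float32"),  # never read in the body
    ) -> R.Tensor((4,), "float32"):
        out = R.add(x, x)
        return out


# Expected (per the tests above): `unused` disappears from main's signature.
After = relax.transform.DeadCodeElimination()(Before)
After.show()
```
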