-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relax] Add Dead Parameter Elimination #18720
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d24a105
c1420f9
cad4eb2
a13d489
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,10 +28,14 @@ | |
| * (those where the bound var is unused and no impure operation is used). | ||
| * 2. Unused Relax functions in the module. | ||
| * We detect the call chain from the entry function, and remove all unused functions. | ||
| * 3. Unused function parameters | ||
| * We detect unused parameters after removing unused VarBindings and remove them from | ||
| * function signatures and call points | ||
| * | ||
| * Any binding blocks that are left empty will be removed by the normalizer. | ||
| */ | ||
|
|
||
| #include <tvm/ffi/cast.h> | ||
| #include <tvm/ffi/reflection/registry.h> | ||
| #include <tvm/ir/analysis.h> | ||
| #include <tvm/relax/analysis.h> | ||
|
|
@@ -91,6 +95,153 @@ IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set<GlobalVar> | |
| return mod; | ||
| } | ||
|
|
||
| // two-stage dead parameter elimination | ||
| // 1. collect all unused parameters with propagation | ||
| // 2. update all call points with new functions and inputs | ||
| std::unordered_map<GlobalVar, std::vector<int>> CollectUsedParamIndices(const IRModule& mod) { | ||
| std::unordered_map<GlobalVar, std::vector<int>> result; | ||
|
|
||
| for (const auto& [gvar, base_func] : mod->functions) { | ||
| if (auto opt_func = base_func.as<Function>()) { | ||
| auto func = opt_func.value(); | ||
| std::vector<bool> used(func->params.size(), false); | ||
|
|
||
| PostOrderVisit(func->body, [&](const ObjectRef& obj) { | ||
| if (auto v = obj.as<VarNode>()) { | ||
| Var var = ffi::GetRef<Var>(v); | ||
| for (size_t i = 0; i < func->params.size(); ++i) { | ||
| if (var.same_as(func->params[i])) { | ||
| used[i] = true; | ||
| } | ||
| } | ||
| } | ||
| }); | ||
|
|
||
| std::vector<int> indices; | ||
| for (size_t i = 0; i < used.size(); ++i) { | ||
| if (used[i]) indices.push_back(i); | ||
| } | ||
|
|
||
| result[gvar] = std::move(indices); | ||
| } | ||
| } | ||
|
|
||
| return result; | ||
| } | ||
|
|
||
| struct CallSiteUpdater : public ExprMutator { | ||
| const std::unordered_map<GlobalVar, std::vector<int>>& used_param_indices; | ||
|
|
||
| explicit CallSiteUpdater(const std::unordered_map<GlobalVar, std::vector<int>>& used) | ||
| : ExprMutator(std::nullopt), used_param_indices(used) {} | ||
|
|
||
| using ExprMutator::VisitExpr_; | ||
|
|
||
| Expr VisitExpr_(const CallNode* call) final { | ||
| if (auto gvar = call->op.as<GlobalVar>()) { | ||
| auto it = used_param_indices.find(gvar.value()); | ||
| if (it != used_param_indices.end()) { | ||
| const auto& used = it->second; | ||
|
|
||
| if (used.size() == call->args.size()) { | ||
| return ExprMutator::VisitExpr_(call); | ||
| } | ||
|
|
||
| ffi::Array<Expr> new_args; | ||
| for (int idx : used) { | ||
| new_args.push_back(call->args[idx]); | ||
| } | ||
|
|
||
| auto new_call = Call(call->op, new_args, call->attrs); | ||
| if (call->struct_info_.defined()) { | ||
| new_call->struct_info_ = call->struct_info_; | ||
| } | ||
|
Comment on lines
+156
to
+158
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Manually copying the |
||
| return new_call; | ||
| } | ||
| } | ||
| return ExprMutator::VisitExpr_(call); | ||
| } | ||
| }; | ||
|
|
||
| IRModule RemoveUnusedParameters(IRModule mod) { | ||
| auto write_ptr = mod.CopyOnWrite(); | ||
| bool changed = true; | ||
|
|
||
| do { | ||
| changed = false; | ||
|
|
||
| auto used_param_indices = CollectUsedParamIndices(mod); | ||
|
|
||
| for (const auto& [gvar, used] : used_param_indices) { | ||
| if (auto opt_func = mod->Lookup(gvar).as<Function>()) { | ||
| auto func = opt_func.value(); | ||
| if (used.size() < func->params.size()) { | ||
| changed = true; | ||
| break; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (!changed) break; | ||
|
|
||
| std::vector<GlobalVar> worklist; | ||
| std::unordered_set<GlobalVar> visited; | ||
| std::function<void(GlobalVar)> dfs = [&](GlobalVar gvar) { | ||
| if (visited.count(gvar)) return; | ||
| visited.insert(gvar); | ||
|
|
||
| if (auto opt_func = mod->Lookup(gvar).as<Function>()) { | ||
| auto func = opt_func.value(); | ||
| PostOrderVisit(func->body, [&](const ObjectRef& obj) { | ||
| if (auto call = obj.as<CallNode>()) { | ||
| if (auto callee_gvar = call->op.as<GlobalVar>()) { | ||
| dfs(callee_gvar.value()); | ||
| } | ||
| } | ||
| }); | ||
| } | ||
| worklist.push_back(gvar); | ||
| }; | ||
|
|
||
| for (const auto& [gvar, _] : mod->functions) { | ||
| dfs(gvar); | ||
| } | ||
|
Comment on lines
+187
to
+208
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| for (const auto& gvar : worklist) { | ||
| if (auto opt_func = mod->Lookup(gvar).as<Function>()) { | ||
| auto func = opt_func.value(); | ||
| auto it = used_param_indices.find(gvar); | ||
| if (it != used_param_indices.end() && it->second.size() < func->params.size()) { | ||
| ffi::Array<Var> new_params; | ||
| for (int old_idx : used_param_indices[gvar]) { | ||
| new_params.push_back(func->params[old_idx]); | ||
| } | ||
|
|
||
| Function new_func = Function(new_params, func->body, func->ret_struct_info, func->is_pure, | ||
| func->attrs, func->span); | ||
| write_ptr->Update(gvar, new_func); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| CallSiteUpdater updater(used_param_indices); | ||
| for (const auto& gvar : worklist) { | ||
| if (auto opt_func = mod->Lookup(gvar).as<Function>()) { | ||
| auto func = opt_func.value(); | ||
| Expr updated_body = updater(func->body); | ||
|
|
||
| if (!updated_body.same_as(func->body)) { | ||
| Function new_func(func->params, updated_body, func->ret_struct_info, func->is_pure, | ||
| func->attrs, func->span); | ||
| write_ptr->Update(gvar, new_func); | ||
| } | ||
| } | ||
| } | ||
| } while (changed); | ||
|
|
||
| return mod; | ||
| } | ||
|
|
||
| IRModule DeadCodeElimination(const IRModule& arg_mod, | ||
| ffi::Array<ffi::String> entry_function_names) { | ||
| IRModule mod = arg_mod; | ||
|
|
@@ -127,9 +278,28 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, | |
| } | ||
| } | ||
|
|
||
| // S3: remove unused functions again as some callers may be removed in S2. | ||
| // S3: remove unused parameters in each function | ||
| mod = RemoveUnusedParameters(mod); | ||
|
|
||
| // S4: remove unused functions again as some callers may be removed in S2 and S3. | ||
| mod = RemoveUnusedFunctions(mod, entry_functions); | ||
|
|
||
| // S5: remove unused variables again as some arguments may be removed in S3 | ||
| { | ||
| IRModule updates; | ||
| for (const auto& [gvar, base_func] : mod->functions) { | ||
| if (auto opt = base_func.as<Function>()) { | ||
| auto new_func = Downcast<Function>(RemoveAllUnused(opt.value())); | ||
| if (!new_func.same_as(base_func)) { | ||
| updates->Add(gvar, new_func); | ||
| } | ||
| } | ||
| } | ||
| if (updates->functions.size()) { | ||
| mod.CopyOnWrite()->Update(updates); | ||
| } | ||
| } | ||
|
Comment on lines
+287
to
+301
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| return mod; | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This approach to finding used parameters involves a nested loop within
PostOrderVisit, which can be inefficient for functions with large bodies. For each variable encountered, it iterates through all function parameters. This can be optimized by building a map from parameterVarto its index before the visit, and then performing a direct lookup for each variable inside the visitor. This would change the complexity from O(num_vars * num_params) to O(num_vars + num_params).