From 2007b1afee2fef162267a92bad4bcc155ed9546f Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Wed, 6 May 2026 00:04:42 +0100 Subject: [PATCH 1/2] feat: add pretty run report --- python/egglog/egraph.py | 15 +++--- python/egglog/egraph_state.py | 26 ++++++++--- python/egglog/run_report.py | 86 +++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 13 deletions(-) create mode 100644 python/egglog/run_report.py diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 8fc6643b..a60f13c3 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -42,6 +42,7 @@ from .egraph_state import * from .ipython_magic import IN_IPYTHON from .pretty import pretty_decl +from .run_report import PrettyRunReport from .runtime import * from .thunk import * @@ -953,15 +954,15 @@ def output(self) -> None: raise NotImplementedError(msg) @overload - def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ... + def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> PrettyRunReport: ... @overload - def run(self, schedule: Schedule, /) -> bindings.RunReport: ... + def run(self, schedule: Schedule, /) -> PrettyRunReport: ... @_TRACER.start_as_current_span("run") def run( self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None - ) -> bindings.RunReport: + ) -> PrettyRunReport: """ Run the egraph until the given limit or until the given facts are true. """ @@ -969,20 +970,20 @@ def run( limit_or_schedule = run(ruleset, *until) * limit_or_schedule return self._run_schedule(limit_or_schedule) - def _run_schedule(self, schedule: Schedule) -> bindings.RunReport: + def _run_schedule(self, schedule: Schedule) -> PrettyRunReport: self._add_decls(schedule) cmd = self._state.run_schedule_to_egg(schedule.schedule) (command_output,) = self._run_program(cmd) assert isinstance(command_output, bindings.RunScheduleOutput) - return command_output.report + return PrettyRunReport.from_bindings(command_output.report, self._state) - def stats(self) -> bindings.RunReport: + def stats(self) -> PrettyRunReport: """ Returns the overall run report for the egraph. """ (output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None)) assert isinstance(output, bindings.OverallStatistics) - return output.report + return PrettyRunReport.from_bindings(output.report, self._state) def check_bool(self, *facts: FactLike) -> bool: """ diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index 1d65aeff..650d0d01 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -76,6 +76,8 @@ class EGraphState: type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict) + egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict) + # Cache of egg expressions for converting to egg expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict) @@ -247,6 +249,14 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912 case _: assert_never(schedule) + def translate_rule_key(self, egglog_key: str) -> str: + """ + Translate an egglog rule name to its Python representation. + """ + if egglog_key in self.egg_rule_to_command_decl: + return pretty_decl(self.__egg_decls__, self.egg_rule_to_command_decl[egglog_key]) + return egglog_key + def ruleset_to_egg(self, ident: Ident) -> None: """ Registers a ruleset if it's not already registered. @@ -289,13 +299,15 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command self._expr_to_egg(rhs), [self.fact_to_egg(c) for c in conditions], ) - return ( - bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) - if isinstance(cmd, RewriteDecl) - else bindings.BiRewriteCommand(str(ruleset), rewrite) - ) + if isinstance(cmd, RewriteDecl): + egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) + else: + egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite) + + self.egg_rule_to_command_decl[str(egg_cmd)] = cmd + return egg_cmd case RuleDecl(head, body, name): - return bindings.RuleCommand( + egg_cmd = bindings.RuleCommand( bindings.Rule( span(), [self.action_to_egg(a) for a in head], @@ -304,6 +316,8 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command str(ruleset), ) ) + self.egg_rule_to_command_decl[str(egg_cmd)] = cmd + return egg_cmd # TODO: Replace with just constants value and looking at REF of function case DefaultRewriteDecl(ref, expr, subsume): sig = self.__egg_decls__.get_callable_decl(ref).signature diff --git a/python/egglog/run_report.py b/python/egglog/run_report.py new file mode 100644 index 00000000..233906c3 --- /dev/null +++ b/python/egglog/run_report.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta + +from . import bindings +from .egraph_state import EGraphState + + +@dataclass +class PrettyRuleReport: + plan: bindings.Plan | None + search_and_apply_time: timedelta + num_matches: int + + @classmethod + def from_bindings(cls, report: bindings.RuleReport) -> PrettyRuleReport: + return cls( + plan=report.plan, + search_and_apply_time=report.search_and_apply_time, + num_matches=report.num_matches, + ) + + +@dataclass +class PrettyRuleSetReport: + changed: bool + rule_reports: dict[str, list[PrettyRuleReport]] + search_and_apply_time: timedelta + merge_time: timedelta + + @classmethod + def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> PrettyRuleSetReport: + return cls( + changed=report.changed, + rule_reports={ + translate_key(k): [PrettyRuleReport.from_bindings(rr) for rr in v] + for k, v in report.rule_reports.items() + }, + search_and_apply_time=report.search_and_apply_time, + merge_time=report.merge_time, + ) + + +@dataclass +class PrettyIterationReport: + rule_set_report: PrettyRuleSetReport + rebuild_time: timedelta + + @classmethod + def from_bindings(cls, report: bindings.IterationReport, translate_key: callable) -> PrettyIterationReport: + return cls( + rule_set_report=PrettyRuleSetReport.from_bindings(report.rule_set_report, translate_key), + rebuild_time=report.rebuild_time, + ) + + +@dataclass +class PrettyRunReport: + """Python-friendly wrapper around bindings.RunReport.""" + + iterations: list[PrettyIterationReport] + updated: bool + search_and_apply_time_per_rule: dict[str, timedelta] + num_matches_per_rule: dict[str, int] + search_and_apply_time_per_ruleset: dict[str, timedelta] + merge_time_per_ruleset: dict[str, timedelta] + rebuild_time_per_ruleset: dict[str, timedelta] + + @classmethod + def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> PrettyRunReport: + return cls( + iterations=[PrettyIterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations], + updated=report.updated, + search_and_apply_time_per_rule={ + state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items() + }, + num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()}, + search_and_apply_time_per_ruleset={ + state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_ruleset.items() + }, + merge_time_per_ruleset={state.translate_rule_key(k): v for k, v in report.merge_time_per_ruleset.items()}, + rebuild_time_per_ruleset={ + state.translate_rule_key(k): v for k, v in report.rebuild_time_per_ruleset.items() + }, + ) From 86133fa2a1d4e388100400f5ae0e91f2a4ff6668 Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Wed, 6 May 2026 00:08:47 +0100 Subject: [PATCH 2/2] feat: add test for pretty run report --- python/tests/test_run_report.py | 200 ++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 python/tests/test_run_report.py diff --git a/python/tests/test_run_report.py b/python/tests/test_run_report.py new file mode 100644 index 00000000..38aa9223 --- /dev/null +++ b/python/tests/test_run_report.py @@ -0,0 +1,200 @@ +# mypy: disable-error-code="empty-body" +from __future__ import annotations + +from datetime import timedelta + +from egglog import * + + +class TestPrettyRunReport: + def _setup_simple_egraph(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(rewrite(x + y).to(y + x)) + egraph.register(Num(1) + Num(2)) + return egraph + + def test_run_returns_pretty_report(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + assert type(report).__name__ == "PrettyRunReport" + + def test_stats_returns_pretty_report(self): + egraph = self._setup_simple_egraph() + egraph.run(10) + report = egraph.stats() + assert type(report).__name__ == "PrettyRunReport" + + def test_rule_names_translated_in_top_level_dicts(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + for key in report.search_and_apply_time_per_rule: + assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" + assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}" + + for key in report.num_matches_per_rule: + assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" + assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}" + + def test_rule_names_translated_in_iterations(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + assert len(report.iterations) > 0 + for iteration in report.iterations: + for key in iteration.rule_set_report.rule_reports: + assert "__main__" not in key, f"Iteration rule key not translated: {key}" + assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" + + def test_updated_field(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + assert isinstance(report.updated, bool) + assert report.updated is True + + def test_num_matches(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + total_matches = sum(report.num_matches_per_rule.values()) + assert total_matches > 0 + + def test_timedelta_types(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + for v in report.search_and_apply_time_per_rule.values(): + assert isinstance(v, timedelta) + for v in report.search_and_apply_time_per_ruleset.values(): + assert isinstance(v, timedelta) + for v in report.merge_time_per_ruleset.values(): + assert isinstance(v, timedelta) + for v in report.rebuild_time_per_ruleset.values(): + assert isinstance(v, timedelta) + + def test_iteration_reports_are_pretty(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + for it in report.iterations: + assert type(it).__name__ == "PrettyIterationReport" + assert type(it.rule_set_report).__name__ == "PrettyRuleSetReport" + for rule_reports in it.rule_set_report.rule_reports.values(): + for rr in rule_reports: + assert type(rr).__name__ == "PrettyRuleReport" + + def test_str_no_egglog_sexprs(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + output = str(report) + + assert "(rewrite" not in output, f"str() still contains egglog s-expressions:\n{output}" + assert "__main__" not in output, f"str() still contains mangled names:\n{output}" + + def test_multiple_rules(self): + egraph = EGraph() + + class Math(Expr): + def __init__(self, value: i64Like) -> None: ... + def __add__(self, other: Math) -> Math: ... + def __mul__(self, other: Math) -> Math: ... + + a, b = vars_("a b", Math) + egraph.register( + rewrite(a + b).to(b + a), + rewrite(a * b).to(b * a), + ) + egraph.register(Math(1) + Math(2), Math(3) * Math(4)) + report = egraph.run(10) + + # should have two distinct translated rule keys + rule_keys = list(report.search_and_apply_time_per_rule.keys()) + assert len(rule_keys) == 2 + for key in rule_keys: + assert "__main__" not in key, f"Key not translated: {key}" + + def test_empty_run(self): + egraph = EGraph() + report = egraph.run(1) + assert type(report).__name__ == "PrettyRunReport" + assert isinstance(report.updated, bool) + + def test_named_rule(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(rule(x + y, name="comm").then(union(x + y).with_(y + x))) + egraph.register(Num(1) + Num(2)) + report = egraph.run(10) + + output = str(report) + assert "__main__" not in output, f"str() still contains mangled names:\n{output}" + + def test_unnamed_rule_decl(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(rule(x + y).then(union(x + y).with_(y + x))) + egraph.register(Num(1) + Num(2)) + report = egraph.run(10) + + output = str(report) + assert "__main__" not in output, f"Unnamed RuleDecl key not translated:\n{output}" + # Should contain Python rule() syntax somewhere in the keys + rule_keys = list(report.search_and_apply_time_per_rule.keys()) + assert len(rule_keys) > 0 + for key in rule_keys: + assert "__main__" not in key, f"RuleDecl key not translated: {key}" + + def test_birewrite_decl(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + def __mul__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(birewrite(x + y).to(y + x)) + egraph.register(Num(1) + Num(2)) + report = egraph.run(10) + + output = str(report) + assert "__main__" not in output, f"BiRewriteDecl key not translated:\n{output}" + rule_keys = list(report.search_and_apply_time_per_rule.keys()) + assert len(rule_keys) > 0 + for key in rule_keys: + assert "__main__" not in key, f"BiRewriteDecl key not translated: {key}" + assert "birewrite" in key, f"Expected birewrite() syntax, got: {key}" + + def test_rewrite_decl(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(rewrite(x + y).to(y + x)) + egraph.register(Num(1) + Num(2)) + report = egraph.run(10) + + rule_keys = list(report.search_and_apply_time_per_rule.keys()) + assert len(rule_keys) == 1 + key = rule_keys[0] + assert "rewrite" in key, f"Expected rewrite() syntax, got: {key}" + assert "__main__" not in key, f"RewriteDecl key not translated: {key}"