-
Notifications
You must be signed in to change notification settings - Fork 22
feat: add pretty run report #416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,6 +76,8 @@ class EGraphState: | |
| type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) | ||
| egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict) | ||
|
|
||
| egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict) | ||
|
|
||
| # Cache of egg expressions for converting to egg | ||
| expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict) | ||
|
|
||
|
|
@@ -247,6 +249,14 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912 | |
| case _: | ||
| assert_never(schedule) | ||
|
|
||
| def translate_rule_key(self, egglog_key: str) -> str: | ||
| """ | ||
| Translate an egglog rule name to its Python representation. | ||
| """ | ||
| if egglog_key in self.egg_rule_to_command_decl: | ||
| return pretty_decl(self.__egg_decls__, self.egg_rule_to_command_decl[egglog_key]) | ||
| return egglog_key | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any cases where it won't find the saved command? If possible seems better to raise an exception instead of falling back. |
||
|
|
||
| def ruleset_to_egg(self, ident: Ident) -> None: | ||
| """ | ||
| Registers a ruleset if it's not already registered. | ||
|
|
@@ -289,13 +299,15 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command | |
| self._expr_to_egg(rhs), | ||
| [self.fact_to_egg(c) for c in conditions], | ||
| ) | ||
| return ( | ||
| bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) | ||
| if isinstance(cmd, RewriteDecl) | ||
| else bindings.BiRewriteCommand(str(ruleset), rewrite) | ||
| ) | ||
| if isinstance(cmd, RewriteDecl): | ||
| egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) | ||
| else: | ||
| egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite) | ||
|
|
||
| self.egg_rule_to_command_decl[str(egg_cmd)] = cmd | ||
| return egg_cmd | ||
| case RuleDecl(head, body, name): | ||
| return bindings.RuleCommand( | ||
| egg_cmd = bindings.RuleCommand( | ||
| bindings.Rule( | ||
| span(), | ||
| [self.action_to_egg(a) for a in head], | ||
|
|
@@ -304,6 +316,8 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command | |
| str(ruleset), | ||
| ) | ||
| ) | ||
| self.egg_rule_to_command_decl[str(egg_cmd)] = cmd | ||
| return egg_cmd | ||
| # TODO: Replace with just constants value and looking at REF of function | ||
| case DefaultRewriteDecl(ref, expr, subsume): | ||
| sig = self.__egg_decls__.get_callable_decl(ref).signature | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| from datetime import timedelta | ||
|
|
||
| from . import bindings | ||
| from .egraph_state import EGraphState | ||
|
|
||
|
|
||
| @dataclass | ||
| class PrettyRuleReport: | ||
| plan: bindings.Plan | None | ||
| search_and_apply_time: timedelta | ||
| num_matches: int | ||
|
|
||
| @classmethod | ||
| def from_bindings(cls, report: bindings.RuleReport) -> PrettyRuleReport: | ||
| return cls( | ||
| plan=report.plan, | ||
| search_and_apply_time=report.search_and_apply_time, | ||
| num_matches=report.num_matches, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class PrettyRuleSetReport: | ||
| changed: bool | ||
| rule_reports: dict[str, list[PrettyRuleReport]] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we try storing them as actual rewrites/rules here instead of strings? That would allow users at runtime for example to extract the rule, run it on something, or compare equality without having to compare string equality. Also just in others forms I usually try to return runtime objects instead of strings. This would also impact how they are stored in e-graph state. Then also the pretty print won't show them as strings. |
||
| search_and_apply_time: timedelta | ||
| merge_time: timedelta | ||
|
|
||
| @classmethod | ||
| def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> PrettyRuleSetReport: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this callable be more strongly typed? |
||
| return cls( | ||
| changed=report.changed, | ||
| rule_reports={ | ||
| translate_key(k): [PrettyRuleReport.from_bindings(rr) for rr in v] | ||
| for k, v in report.rule_reports.items() | ||
| }, | ||
| search_and_apply_time=report.search_and_apply_time, | ||
| merge_time=report.merge_time, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class PrettyIterationReport: | ||
| rule_set_report: PrettyRuleSetReport | ||
| rebuild_time: timedelta | ||
|
|
||
| @classmethod | ||
| def from_bindings(cls, report: bindings.IterationReport, translate_key: callable) -> PrettyIterationReport: | ||
| return cls( | ||
| rule_set_report=PrettyRuleSetReport.from_bindings(report.rule_set_report, translate_key), | ||
| rebuild_time=report.rebuild_time, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class PrettyRunReport: | ||
| """Python-friendly wrapper around bindings.RunReport.""" | ||
|
|
||
| iterations: list[PrettyIterationReport] | ||
| updated: bool | ||
| search_and_apply_time_per_rule: dict[str, timedelta] | ||
| num_matches_per_rule: dict[str, int] | ||
| search_and_apply_time_per_ruleset: dict[str, timedelta] | ||
| merge_time_per_ruleset: dict[str, timedelta] | ||
| rebuild_time_per_ruleset: dict[str, timedelta] | ||
|
|
||
| @classmethod | ||
| def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> PrettyRunReport: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we make these private, prefixing with |
||
| return cls( | ||
| iterations=[PrettyIterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations], | ||
| updated=report.updated, | ||
| search_and_apply_time_per_rule={ | ||
| state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items() | ||
| }, | ||
| num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()}, | ||
| search_and_apply_time_per_ruleset={ | ||
| state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_ruleset.items() | ||
| }, | ||
| merge_time_per_ruleset={state.translate_rule_key(k): v for k, v in report.merge_time_per_ruleset.items()}, | ||
| rebuild_time_per_ruleset={ | ||
| state.translate_rule_key(k): v for k, v in report.rebuild_time_per_ruleset.items() | ||
| }, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,200 @@ | ||
| # mypy: disable-error-code="empty-body" | ||
| from __future__ import annotations | ||
|
|
||
| from datetime import timedelta | ||
|
|
||
| from egglog import * | ||
|
|
||
|
|
||
| class TestPrettyRunReport: | ||
| def _setup_simple_egraph(self): | ||
| egraph = EGraph() | ||
|
|
||
| class Num(Expr): | ||
| def __init__(self, n: i64Like) -> None: ... | ||
| def __add__(self, other: Num) -> Num: ... | ||
|
|
||
| x, y = vars_("x y", Num) | ||
| egraph.register(rewrite(x + y).to(y + x)) | ||
| egraph.register(Num(1) + Num(2)) | ||
| return egraph | ||
|
|
||
| def test_run_returns_pretty_report(self): | ||
| egraph = self._setup_simple_egraph() | ||
| report = egraph.run(10) | ||
| assert type(report).__name__ == "PrettyRunReport" | ||
|
|
||
| def test_stats_returns_pretty_report(self): | ||
| egraph = self._setup_simple_egraph() | ||
| egraph.run(10) | ||
| report = egraph.stats() | ||
| assert type(report).__name__ == "PrettyRunReport" | ||
|
|
||
| def test_rule_names_translated_in_top_level_dicts(self): | ||
| egraph = self._setup_simple_egraph() | ||
| report = egraph.run(10) | ||
|
|
||
| for key in report.search_and_apply_time_per_rule: | ||
| assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" | ||
| assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}" | ||
|
|
||
| for key in report.num_matches_per_rule: | ||
| assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" | ||
| assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}" | ||
|
|
||
| def test_rule_names_translated_in_iterations(self): | ||
| egraph = self._setup_simple_egraph() | ||
| report = egraph.run(10) | ||
|
|
||
| assert len(report.iterations) > 0 | ||
| for iteration in report.iterations: | ||
| for key in iteration.rule_set_report.rule_reports: | ||
| assert "__main__" not in key, f"Iteration rule key not translated: {key}" | ||
| assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" | ||
|
|
||
| def test_updated_field(self): | ||
| egraph = self._setup_simple_egraph() | ||
| report = egraph.run(10) | ||
| assert isinstance(report.updated, bool) | ||
| assert report.updated is True | ||
|
|
||
| def test_num_matches(self): | ||
| egraph = self._setup_simple_egraph() | ||
| report = egraph.run(10) | ||
|
|
||
| total_matches = sum(report.num_matches_per_rule.values()) | ||
| assert total_matches > 0 | ||
|
|
||
| def test_timedelta_types(self): | ||
| egraph = self._setup_simple_egraph() | ||
| report = egraph.run(10) | ||
|
|
||
| for v in report.search_and_apply_time_per_rule.values(): | ||
| assert isinstance(v, timedelta) | ||
| for v in report.search_and_apply_time_per_ruleset.values(): | ||
| assert isinstance(v, timedelta) | ||
| for v in report.merge_time_per_ruleset.values(): | ||
| assert isinstance(v, timedelta) | ||
| for v in report.rebuild_time_per_ruleset.values(): | ||
| assert isinstance(v, timedelta) | ||
|
|
||
| def test_iteration_reports_are_pretty(self): | ||
| egraph = self._setup_simple_egraph() | ||
| report = egraph.run(10) | ||
|
|
||
| for it in report.iterations: | ||
| assert type(it).__name__ == "PrettyIterationReport" | ||
| assert type(it.rule_set_report).__name__ == "PrettyRuleSetReport" | ||
| for rule_reports in it.rule_set_report.rule_reports.values(): | ||
| for rr in rule_reports: | ||
| assert type(rr).__name__ == "PrettyRuleReport" | ||
|
|
||
| def test_str_no_egglog_sexprs(self): | ||
| egraph = self._setup_simple_egraph() | ||
| report = egraph.run(10) | ||
| output = str(report) | ||
|
|
||
| assert "(rewrite" not in output, f"str() still contains egglog s-expressions:\n{output}" | ||
| assert "__main__" not in output, f"str() still contains mangled names:\n{output}" | ||
|
|
||
| def test_multiple_rules(self): | ||
| egraph = EGraph() | ||
|
|
||
| class Math(Expr): | ||
| def __init__(self, value: i64Like) -> None: ... | ||
| def __add__(self, other: Math) -> Math: ... | ||
| def __mul__(self, other: Math) -> Math: ... | ||
|
|
||
| a, b = vars_("a b", Math) | ||
| egraph.register( | ||
| rewrite(a + b).to(b + a), | ||
| rewrite(a * b).to(b * a), | ||
| ) | ||
| egraph.register(Math(1) + Math(2), Math(3) * Math(4)) | ||
| report = egraph.run(10) | ||
|
|
||
| # should have two distinct translated rule keys | ||
| rule_keys = list(report.search_and_apply_time_per_rule.keys()) | ||
| assert len(rule_keys) == 2 | ||
| for key in rule_keys: | ||
| assert "__main__" not in key, f"Key not translated: {key}" | ||
|
|
||
| def test_empty_run(self): | ||
| egraph = EGraph() | ||
| report = egraph.run(1) | ||
| assert type(report).__name__ == "PrettyRunReport" | ||
| assert isinstance(report.updated, bool) | ||
|
|
||
| def test_named_rule(self): | ||
| egraph = EGraph() | ||
|
|
||
| class Num(Expr): | ||
| def __init__(self, n: i64Like) -> None: ... | ||
| def __add__(self, other: Num) -> Num: ... | ||
|
|
||
| x, y = vars_("x y", Num) | ||
| egraph.register(rule(x + y, name="comm").then(union(x + y).with_(y + x))) | ||
| egraph.register(Num(1) + Num(2)) | ||
| report = egraph.run(10) | ||
|
|
||
| output = str(report) | ||
| assert "__main__" not in output, f"str() still contains mangled names:\n{output}" | ||
|
|
||
| def test_unnamed_rule_decl(self): | ||
| egraph = EGraph() | ||
|
|
||
| class Num(Expr): | ||
| def __init__(self, n: i64Like) -> None: ... | ||
| def __add__(self, other: Num) -> Num: ... | ||
|
|
||
| x, y = vars_("x y", Num) | ||
| egraph.register(rule(x + y).then(union(x + y).with_(y + x))) | ||
| egraph.register(Num(1) + Num(2)) | ||
| report = egraph.run(10) | ||
|
|
||
| output = str(report) | ||
| assert "__main__" not in output, f"Unnamed RuleDecl key not translated:\n{output}" | ||
| # Should contain Python rule() syntax somewhere in the keys | ||
| rule_keys = list(report.search_and_apply_time_per_rule.keys()) | ||
| assert len(rule_keys) > 0 | ||
| for key in rule_keys: | ||
| assert "__main__" not in key, f"RuleDecl key not translated: {key}" | ||
|
|
||
| def test_birewrite_decl(self): | ||
| egraph = EGraph() | ||
|
|
||
| class Num(Expr): | ||
| def __init__(self, n: i64Like) -> None: ... | ||
| def __add__(self, other: Num) -> Num: ... | ||
| def __mul__(self, other: Num) -> Num: ... | ||
|
|
||
| x, y = vars_("x y", Num) | ||
| egraph.register(birewrite(x + y).to(y + x)) | ||
| egraph.register(Num(1) + Num(2)) | ||
| report = egraph.run(10) | ||
|
|
||
| output = str(report) | ||
| assert "__main__" not in output, f"BiRewriteDecl key not translated:\n{output}" | ||
| rule_keys = list(report.search_and_apply_time_per_rule.keys()) | ||
| assert len(rule_keys) > 0 | ||
| for key in rule_keys: | ||
| assert "__main__" not in key, f"BiRewriteDecl key not translated: {key}" | ||
| assert "birewrite" in key, f"Expected birewrite() syntax, got: {key}" | ||
|
|
||
| def test_rewrite_decl(self): | ||
| egraph = EGraph() | ||
|
|
||
| class Num(Expr): | ||
| def __init__(self, n: i64Like) -> None: ... | ||
| def __add__(self, other: Num) -> Num: ... | ||
|
|
||
| x, y = vars_("x y", Num) | ||
| egraph.register(rewrite(x + y).to(y + x)) | ||
| egraph.register(Num(1) + Num(2)) | ||
| report = egraph.run(10) | ||
|
|
||
| rule_keys = list(report.search_and_apply_time_per_rule.keys()) | ||
| assert len(rule_keys) == 1 | ||
| key = rule_keys[0] | ||
| assert "rewrite" in key, f"Expected rewrite() syntax, got: {key}" | ||
| assert "__main__" not in key, f"RewriteDecl key not translated: {key}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can just call it
RunReport. I have a few places I use the same name in the high level in bindings, such asEGraphitself. Since the hope is that users won't ever have to import frombindings, this shouldn't clobber the name.