Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .egraph_state import *
from .ipython_magic import IN_IPYTHON
from .pretty import pretty_decl
from .run_report import PrettyRunReport
from .runtime import *
from .thunk import *

Expand Down Expand Up @@ -953,36 +954,36 @@ def output(self) -> None:
raise NotImplementedError(msg)

@overload
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ...
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> PrettyRunReport: ...
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just call it RunReport. There are a few places where I use the same name in the high-level API as in bindings, such as EGraph itself. Since the hope is that users won't ever have to import from bindings, this shouldn't clobber the name.


@overload
def run(self, schedule: Schedule, /) -> bindings.RunReport: ...
def run(self, schedule: Schedule, /) -> PrettyRunReport: ...

@_TRACER.start_as_current_span("run")
def run(
self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
) -> bindings.RunReport:
) -> PrettyRunReport:
"""
Run the egraph until the given limit or until the given facts are true.
"""
if isinstance(limit_or_schedule, int):
limit_or_schedule = run(ruleset, *until) * limit_or_schedule
return self._run_schedule(limit_or_schedule)

def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
def _run_schedule(self, schedule: Schedule) -> PrettyRunReport:
self._add_decls(schedule)
cmd = self._state.run_schedule_to_egg(schedule.schedule)
(command_output,) = self._run_program(cmd)
assert isinstance(command_output, bindings.RunScheduleOutput)
return command_output.report
return PrettyRunReport.from_bindings(command_output.report, self._state)

def stats(self) -> bindings.RunReport:
def stats(self) -> PrettyRunReport:
"""
Returns the overall run report for the egraph.
"""
(output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None))
assert isinstance(output, bindings.OverallStatistics)
return output.report
return PrettyRunReport.from_bindings(output.report, self._state)

def check_bool(self, *facts: FactLike) -> bool:
"""
Expand Down
26 changes: 20 additions & 6 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class EGraphState:
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)

egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)

# Cache of egg expressions for converting to egg
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)

Expand Down Expand Up @@ -247,6 +249,14 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
case _:
assert_never(schedule)

def translate_rule_key(self, egglog_key: str) -> str:
    """
    Map an egglog rule key to the pretty-printed form of the command recorded for it.

    Keys with no entry in ``egg_rule_to_command_decl`` are returned unchanged.
    """
    recorded_decl = self.egg_rule_to_command_decl.get(egglog_key)
    if recorded_decl is None:
        # Nothing was recorded under this key, so hand it back as-is.
        return egglog_key
    return pretty_decl(self.__egg_decls__, recorded_decl)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any cases where it won't find the saved command? If possible, it seems better to raise an exception instead of falling back.


def ruleset_to_egg(self, ident: Ident) -> None:
"""
Registers a ruleset if it's not already registered.
Expand Down Expand Up @@ -289,13 +299,15 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
self._expr_to_egg(rhs),
[self.fact_to_egg(c) for c in conditions],
)
return (
bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
if isinstance(cmd, RewriteDecl)
else bindings.BiRewriteCommand(str(ruleset), rewrite)
)
if isinstance(cmd, RewriteDecl):
egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
else:
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)

self.egg_rule_to_command_decl[str(egg_cmd)] = cmd
return egg_cmd
case RuleDecl(head, body, name):
return bindings.RuleCommand(
egg_cmd = bindings.RuleCommand(
bindings.Rule(
span(),
[self.action_to_egg(a) for a in head],
Expand All @@ -304,6 +316,8 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
str(ruleset),
)
)
self.egg_rule_to_command_decl[str(egg_cmd)] = cmd
return egg_cmd
# TODO: Replace with just constants value and looking at REF of function
case DefaultRewriteDecl(ref, expr, subsume):
sig = self.__egg_decls__.get_callable_decl(ref).signature
Expand Down
86 changes: 86 additions & 0 deletions python/egglog/run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta

from . import bindings
from .egraph_state import EGraphState


@dataclass
class PrettyRuleReport:
    """
    Plain dataclass copy of the fields read from a single ``bindings.RuleReport``.
    """

    plan: bindings.Plan | None
    search_and_apply_time: timedelta
    num_matches: int

    @classmethod
    def from_bindings(cls, report: bindings.RuleReport) -> PrettyRuleReport:
        """
        Create an instance by copying ``plan``, ``search_and_apply_time`` and ``num_matches`` from ``report``.
        """
        plan = report.plan
        elapsed = report.search_and_apply_time
        matches = report.num_matches
        return cls(plan=plan, search_and_apply_time=elapsed, num_matches=matches)


@dataclass
class PrettyRuleSetReport:
changed: bool
rule_reports: dict[str, list[PrettyRuleReport]]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we try storing them as actual rewrites/rules here instead of strings? That would allow users at runtime, for example, to extract the rule, run it on something, or compare rules for equality without having to compare strings. Also, in other forms I usually try to return runtime objects instead of strings.

This would also impact how they are stored in e-graph state.

Then also the pretty print won't show them as strings.

search_and_apply_time: timedelta
merge_time: timedelta

@classmethod
def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> PrettyRuleSetReport:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this callable be more strongly typed?

return cls(
changed=report.changed,
rule_reports={
translate_key(k): [PrettyRuleReport.from_bindings(rr) for rr in v]
for k, v in report.rule_reports.items()
},
search_and_apply_time=report.search_and_apply_time,
merge_time=report.merge_time,
)


@dataclass
class PrettyIterationReport:
    """
    Python-side copy of one iteration's report.

    Holds the converted rule set report together with the iteration's rebuild time.
    """

    # Rule set report for this iteration, converted via PrettyRuleSetReport.from_bindings.
    rule_set_report: PrettyRuleSetReport
    # Rebuild time, copied unchanged from the source report.
    rebuild_time: timedelta

    @classmethod
    def from_bindings(
        cls, report: bindings.IterationReport, translate_key: Callable[[str], str]
    ) -> PrettyIterationReport:
        """
        Build a PrettyIterationReport from a bindings iteration report.

        ``translate_key`` maps a rule key string to the string that should replace it; it is
        forwarded unchanged to ``PrettyRuleSetReport.from_bindings``.
        """
        return cls(
            rule_set_report=PrettyRuleSetReport.from_bindings(report.rule_set_report, translate_key),
            rebuild_time=report.rebuild_time,
        )


@dataclass
class PrettyRunReport:
"""Python-friendly wrapper around bindings.RunReport."""

iterations: list[PrettyIterationReport]
updated: bool
search_and_apply_time_per_rule: dict[str, timedelta]
num_matches_per_rule: dict[str, int]
search_and_apply_time_per_ruleset: dict[str, timedelta]
merge_time_per_ruleset: dict[str, timedelta]
rebuild_time_per_ruleset: dict[str, timedelta]

@classmethod
def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> PrettyRunReport:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make these private, prefixing with _? I don't imagine that users would need to be able to translate between bindings and this.

return cls(
iterations=[PrettyIterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations],
updated=report.updated,
search_and_apply_time_per_rule={
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items()
},
num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()},
search_and_apply_time_per_ruleset={
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_ruleset.items()
},
merge_time_per_ruleset={state.translate_rule_key(k): v for k, v in report.merge_time_per_ruleset.items()},
rebuild_time_per_ruleset={
state.translate_rule_key(k): v for k, v in report.rebuild_time_per_ruleset.items()
},
)
200 changes: 200 additions & 0 deletions python/tests/test_run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# mypy: disable-error-code="empty-body"
from __future__ import annotations

from datetime import timedelta

from egglog import *


class TestPrettyRunReport:
# Tests for the report objects returned by EGraph.run() and EGraph.stats().
# Types are compared by class name rather than isinstance because the report
# classes are not referenced by name from this module's imports.
def _setup_simple_egraph(self):
# Helper: returns an EGraph with one commutativity rewrite for Num addition
# and the expression Num(1) + Num(2) registered.
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(rewrite(x + y).to(y + x))
egraph.register(Num(1) + Num(2))
return egraph

def test_run_returns_pretty_report(self):
# run() should hand back a PrettyRunReport.
egraph = self._setup_simple_egraph()
report = egraph.run(10)
assert type(report).__name__ == "PrettyRunReport"

def test_stats_returns_pretty_report(self):
# stats() should hand back a PrettyRunReport after a run.
egraph = self._setup_simple_egraph()
egraph.run(10)
report = egraph.stats()
assert type(report).__name__ == "PrettyRunReport"

def test_rule_names_translated_in_top_level_dicts(self):
# Keys of the per-rule dicts on the report should contain "rewrite"
# and must not contain the "__main__" module prefix.
egraph = self._setup_simple_egraph()
report = egraph.run(10)

for key in report.search_and_apply_time_per_rule:
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"
assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}"

for key in report.num_matches_per_rule:
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"
assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}"

def test_rule_names_translated_in_iterations(self):
# The same key checks, applied to rule_reports inside every iteration;
# at least one iteration is required.
egraph = self._setup_simple_egraph()
report = egraph.run(10)

assert len(report.iterations) > 0
for iteration in report.iterations:
for key in iteration.rule_set_report.rule_reports:
assert "__main__" not in key, f"Iteration rule key not translated: {key}"
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"

def test_updated_field(self):
# `updated` is a bool and is True for the simple egraph run.
egraph = self._setup_simple_egraph()
report = egraph.run(10)
assert isinstance(report.updated, bool)
assert report.updated is True

def test_num_matches(self):
# The match counts summed over all rules must be positive.
egraph = self._setup_simple_egraph()
report = egraph.run(10)

total_matches = sum(report.num_matches_per_rule.values())
assert total_matches > 0

def test_timedelta_types(self):
# Every value in the four timing dicts must be a datetime.timedelta.
egraph = self._setup_simple_egraph()
report = egraph.run(10)

for v in report.search_and_apply_time_per_rule.values():
assert isinstance(v, timedelta)
for v in report.search_and_apply_time_per_ruleset.values():
assert isinstance(v, timedelta)
for v in report.merge_time_per_ruleset.values():
assert isinstance(v, timedelta)
for v in report.rebuild_time_per_ruleset.values():
assert isinstance(v, timedelta)

def test_iteration_reports_are_pretty(self):
# Each nested level (iteration, rule set, rule) must be the Pretty* class.
egraph = self._setup_simple_egraph()
report = egraph.run(10)

for it in report.iterations:
assert type(it).__name__ == "PrettyIterationReport"
assert type(it.rule_set_report).__name__ == "PrettyRuleSetReport"
for rule_reports in it.rule_set_report.rule_reports.values():
for rr in rule_reports:
assert type(rr).__name__ == "PrettyRuleReport"

def test_str_no_egglog_sexprs(self):
# str(report) must contain neither "(rewrite" nor "__main__".
egraph = self._setup_simple_egraph()
report = egraph.run(10)
output = str(report)

assert "(rewrite" not in output, f"str() still contains egglog s-expressions:\n{output}"
assert "__main__" not in output, f"str() still contains mangled names:\n{output}"

def test_multiple_rules(self):
# Two registered rewrites (for + and *) must yield exactly two per-rule keys.
egraph = EGraph()

class Math(Expr):
def __init__(self, value: i64Like) -> None: ...
def __add__(self, other: Math) -> Math: ...
def __mul__(self, other: Math) -> Math: ...

a, b = vars_("a b", Math)
egraph.register(
rewrite(a + b).to(b + a),
rewrite(a * b).to(b * a),
)
egraph.register(Math(1) + Math(2), Math(3) * Math(4))
report = egraph.run(10)

# should have two distinct translated rule keys
rule_keys = list(report.search_and_apply_time_per_rule.keys())
assert len(rule_keys) == 2
for key in rule_keys:
assert "__main__" not in key, f"Key not translated: {key}"

def test_empty_run(self):
# Running an egraph with nothing registered still returns a
# PrettyRunReport whose `updated` is a bool.
egraph = EGraph()
report = egraph.run(1)
assert type(report).__name__ == "PrettyRunReport"
assert isinstance(report.updated, bool)

def test_named_rule(self):
# A rule registered with name="comm": str(report) must not contain "__main__".
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(rule(x + y, name="comm").then(union(x + y).with_(y + x)))
egraph.register(Num(1) + Num(2))
report = egraph.run(10)

output = str(report)
assert "__main__" not in output, f"str() still contains mangled names:\n{output}"

def test_unnamed_rule_decl(self):
# A rule() registered without a name: neither str(report) nor any
# per-rule key may contain "__main__", and at least one key must exist.
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(rule(x + y).then(union(x + y).with_(y + x)))
egraph.register(Num(1) + Num(2))
report = egraph.run(10)

output = str(report)
assert "__main__" not in output, f"Unnamed RuleDecl key not translated:\n{output}"
# Should contain Python rule() syntax somewhere in the keys
rule_keys = list(report.search_and_apply_time_per_rule.keys())
assert len(rule_keys) > 0
for key in rule_keys:
assert "__main__" not in key, f"RuleDecl key not translated: {key}"

def test_birewrite_decl(self):
# A birewrite(): every per-rule key must contain "birewrite" and no "__main__".
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...
def __mul__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(birewrite(x + y).to(y + x))
egraph.register(Num(1) + Num(2))
report = egraph.run(10)

output = str(report)
assert "__main__" not in output, f"BiRewriteDecl key not translated:\n{output}"
rule_keys = list(report.search_and_apply_time_per_rule.keys())
assert len(rule_keys) > 0
for key in rule_keys:
assert "__main__" not in key, f"BiRewriteDecl key not translated: {key}"
assert "birewrite" in key, f"Expected birewrite() syntax, got: {key}"

def test_rewrite_decl(self):
# A single rewrite(): exactly one per-rule key, containing "rewrite" and no "__main__".
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(rewrite(x + y).to(y + x))
egraph.register(Num(1) + Num(2))
report = egraph.run(10)

rule_keys = list(report.search_and_apply_time_per_rule.keys())
assert len(rule_keys) == 1
key = rule_keys[0]
assert "rewrite" in key, f"Expected rewrite() syntax, got: {key}"
assert "__main__" not in key, f"RewriteDecl key not translated: {key}"
Loading