Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
793 changes: 439 additions & 354 deletions Cargo.lock

Large diffs are not rendered by default.

22 changes: 10 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ opentelemetry_sdk = "0.28"
# egglog-core-relations = { path = "../egg-smol/core-relations" }
# egglog-ast = { path = "../egg-smol/egglog-ast" }
# egglog-reports = { path = "../egg-smol/egglog-reports" }
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug", default-features = false }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }


egglog = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b", default-features = false }
egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-reports = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "main", default-features = false }
egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] }
serde_json = "1"
Expand All @@ -52,11 +50,11 @@ base64 = "0.22.1"
# egglog-reports = { path = "../egg-smol/egglog-reports" }
# egglog-bridge = { path = "../egg-smol/egglog-bridge" }

egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }

# enable debug symbols for easier profiling
[profile.release]
Expand Down
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _This project uses semantic versioning_

## 13.1.0 (2026-03-25)

- Add Python-friendly `RunReport` wrapper that returns `CommandDecl` objects as rule keys instead of raw egglog s-expression strings, with pretty-printed Python syntax in `str()` output [#416](https://github.com/egraphs-good/egglog-python/pull/416)
- Improve high-level Python ergonomics and docs [#397](https://github.com/egraphs-good/egglog-python/pull/397)
- Add `EGraph.freeze()`, returning a `FrozenEGraph` snapshot that can be pretty-printed back into replayable high-level Python actions for debugging and inspection.
- Add a variadic `EGraph(*actions, seminaive=True, save_egglog_string=False)` constructor so actions can be registered at construction time, and export `ActionLike` from `egglog` for typing code that works with `EGraph.register(...)` and the constructor.
Expand Down
5 changes: 4 additions & 1 deletion python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,11 @@ class Rewrite:
lhs: _Expr
rhs: _Expr
conditions: list[_Fact]
name: str

def __new__(cls, span: _Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = ...) -> Rewrite: ...
def __new__(
cls, span: _Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = ..., name: str = ...
) -> Rewrite: ...

@final
class RunConfig:
Expand Down
18 changes: 9 additions & 9 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .egraph_state import *
from .ipython_magic import IN_IPYTHON
from .pretty import pretty_decl
from .run_report import RunReport
from .runtime import *
from .thunk import *

Comment thread
kaeun97 marked this conversation as resolved.
Expand Down Expand Up @@ -70,6 +71,7 @@
"GreedyDagCost",
"RewriteOrRule",
"Ruleset",
"RunReport",
"Schedule",
"_BirewriteBuilder",
"_EqBuilder",
Expand Down Expand Up @@ -953,36 +955,34 @@ def output(self) -> None:
raise NotImplementedError(msg)

@overload
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ...
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport: ...

@overload
def run(self, schedule: Schedule, /) -> bindings.RunReport: ...
def run(self, schedule: Schedule, /) -> RunReport: ...

@_TRACER.start_as_current_span("run")
def run(
self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
) -> bindings.RunReport:
def run(self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport:
"""
Run the egraph until the given limit or until the given facts are true.
"""
if isinstance(limit_or_schedule, int):
limit_or_schedule = run(ruleset, *until) * limit_or_schedule
return self._run_schedule(limit_or_schedule)

def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
def _run_schedule(self, schedule: Schedule) -> RunReport:
self._add_decls(schedule)
cmd = self._state.run_schedule_to_egg(schedule.schedule)
(command_output,) = self._run_program(cmd)
assert isinstance(command_output, bindings.RunScheduleOutput)
return command_output.report
return RunReport._from_bindings(command_output.report, self._state)

def stats(self) -> bindings.RunReport:
def stats(self) -> RunReport:
"""
Returns the overall run report for the egraph.
"""
(output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None))
assert isinstance(output, bindings.OverallStatistics)
return output.report
return RunReport._from_bindings(output.report, self._state)

def check_bool(self, *facts: FactLike) -> bool:
"""
Expand Down
30 changes: 24 additions & 6 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ class EGraphState:
# Counter for deterministic synthetic names assigned to unnamed functions.
unnamed_function_counter: int = 0

# Counter for numeric rule names
rule_name_counter: int = 0
# Mapping from numeric name (str) to command decl
rule_name_to_command_decl: dict[str, RuleDecl | BiRewriteDecl | RewriteDecl] = field(default_factory=dict)

def copy(self) -> EGraphState:
"""
Returns a copy of the state. The egraph reference is kept the same. Used for pushing/popping.
Expand All @@ -102,6 +107,8 @@ def copy(self) -> EGraphState:
cost_callables=self.cost_callables.copy(),
expr_to_let_counter=self.expr_to_let_counter,
unnamed_function_counter=self.unnamed_function_counter,
rule_name_counter=self.rule_name_counter,
rule_name_to_command_decl=self.rule_name_to_command_decl.copy(),
)

def _run_program(self, *commands: bindings._Command) -> list[bindings._CommandOutput]:
Expand Down Expand Up @@ -283,24 +290,35 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
return bindings.ActionCommand(action_egg)
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
self.type_ref_to_egg(tp)
name = str(self.rule_name_counter)
self.rule_name_counter += 1
Comment thread
kaeun97 marked this conversation as resolved.
rewrite = bindings.Rewrite(
span(),
self._expr_to_egg(lhs),
self._expr_to_egg(rhs),
[self.fact_to_egg(c) for c in conditions],
name,
)
return (
bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
if isinstance(cmd, RewriteDecl)
else bindings.BiRewriteCommand(str(ruleset), rewrite)
)
egg_cmd: bindings._Command
if isinstance(cmd, RewriteDecl):
self.rule_name_to_command_decl[name] = cmd
egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
else:
self.rule_name_to_command_decl[f"{name}=>"] = cmd
self.rule_name_to_command_decl[f"{name}<="] = cmd
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)
return egg_cmd
case RuleDecl(head, body, name):
if not name:
name = str(self.rule_name_counter)
self.rule_name_counter += 1
self.rule_name_to_command_decl[name] = cmd
return bindings.RuleCommand(
bindings.Rule(
span(),
[self.action_to_egg(a) for a in head],
[self.fact_to_egg(f) for f in body],
name or "",
name,
str(ruleset),
)
)
Expand Down
142 changes: 142 additions & 0 deletions python/egglog/run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import timedelta

from . import bindings
from .declarations import BiRewriteDecl, Declarations, RewriteDecl, RuleDecl
from .egraph_state import EGraphState
from .pretty import pretty_decl

RewriteOrRuleDecl = RuleDecl | BiRewriteDecl | RewriteDecl


def _format_rule_key(decls: Declarations, key: RewriteOrRuleDecl) -> str:
return pretty_decl(decls, key)


@dataclass
class RuleReport:
plan: bindings.Plan | None
search_and_apply_time: timedelta
num_matches: int

@classmethod
def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
return cls(
plan=report.plan,
search_and_apply_time=report.search_and_apply_time,
num_matches=report.num_matches,
)


@dataclass
class RuleSetReport:
_decls: Declarations = field(repr=False)
changed: bool = False
rule_reports: dict[RewriteOrRuleDecl, list[RuleReport]] = field(default_factory=dict)
search_and_apply_time: timedelta = field(default_factory=timedelta)
merge_time: timedelta = field(default_factory=timedelta)

@classmethod
def _from_bindings(
cls, report: bindings.RuleSetReport, rule_map: dict[str, RewriteOrRuleDecl], decls: Declarations
) -> RuleSetReport:
rule_reports: dict[RewriteOrRuleDecl, list[RuleReport]] = {}
for k, v in report.rule_reports.items():
translated = rule_map[k]
reports = [RuleReport._from_bindings(rr) for rr in v]
if translated in rule_reports:
rule_reports[translated].extend(reports)
else:
rule_reports[translated] = reports
return cls(
_decls=decls,
changed=report.changed,
rule_reports=rule_reports,
search_and_apply_time=report.search_and_apply_time,
merge_time=report.merge_time,
)

def __repr__(self) -> str:
rule_reports_str = {_format_rule_key(self._decls, k): v for k, v in self.rule_reports.items()}
return (
f"RuleSetReport(changed={self.changed}, "
f"rule_reports={rule_reports_str}, "
f"search_and_apply_time={self.search_and_apply_time}, "
f"merge_time={self.merge_time})"
)


@dataclass
class IterationReport:
rule_set_report: RuleSetReport
rebuild_time: timedelta

@classmethod
def _from_bindings(
cls, report: bindings.IterationReport, rule_map: dict[str, RewriteOrRuleDecl], decls: Declarations
) -> IterationReport:
return cls(
rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, rule_map, decls),
rebuild_time=report.rebuild_time,
)


@dataclass
class RunReport:
"""Python-friendly wrapper around bindings.RunReport."""

_decls: Declarations = field(repr=False)
iterations: list[IterationReport] = field(default_factory=list)
updated: bool = False
search_and_apply_time_per_rule: dict[RewriteOrRuleDecl, timedelta] = field(default_factory=dict)
num_matches_per_rule: dict[RewriteOrRuleDecl, int] = field(default_factory=dict)
search_and_apply_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
merge_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
rebuild_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)

def __repr__(self) -> str:
time_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.search_and_apply_time_per_rule.items()}
matches_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.num_matches_per_rule.items()}
return (
f"RunReport(iterations={self.iterations}, "
f"updated={self.updated}, "
f"search_and_apply_time_per_rule={time_per_rule}, "
f"num_matches_per_rule={matches_per_rule}, "
f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, "
f"merge_time_per_ruleset={self.merge_time_per_ruleset}, "
f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})"
)

@classmethod
def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport:
rule_map = state.rule_name_to_command_decl
decls = state.__egg_decls__

search_and_apply_time_per_rule: dict[RewriteOrRuleDecl, timedelta] = {}
for k, v in report.search_and_apply_time_per_rule.items():
translated = rule_map[k]
if translated in search_and_apply_time_per_rule:
search_and_apply_time_per_rule[translated] += v
else:
search_and_apply_time_per_rule[translated] = v

num_matches_per_rule: dict[RewriteOrRuleDecl, int] = {}
for k, v in report.num_matches_per_rule.items():
translated = rule_map[k]
if translated in num_matches_per_rule:
num_matches_per_rule[translated] += v
else:
num_matches_per_rule[translated] = v

return cls(
_decls=decls,
iterations=[IterationReport._from_bindings(it, rule_map, decls) for it in report.iterations],
updated=report.updated,
search_and_apply_time_per_rule=search_and_apply_time_per_rule,
num_matches_per_rule=num_matches_per_rule,
search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset,
merge_time_per_ruleset=report.merge_time_per_ruleset,
rebuild_time_per_ruleset=report.rebuild_time_per_ruleset,
)
Loading
Loading