Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7b5797a
Save local changes
FTRobbin Apr 24, 2025
7ec84fd
Things are starting to blow up
FTRobbin Apr 28, 2025
635a2b7
Threading ExtractorView
FTRobbin May 1, 2025
1ab4bdd
Extraction core logic
FTRobbin May 6, 2025
886ba3e
Reconstruction working
FTRobbin May 8, 2025
90879c9
Output for extracting a single best
FTRobbin May 8, 2025
b220785
Hacks for testing
FTRobbin May 9, 2025
f847d16
Test passing
FTRobbin May 9, 2025
49757a3
Fixing rb
FTRobbin May 9, 2025
feb2d1f
Fixed unstable-fn extraction
FTRobbin May 9, 2025
77b3e23
Nits
FTRobbin May 9, 2025
2e756a6
Support extract variants
FTRobbin May 12, 2025
375ca0e
Minor
FTRobbin May 12, 2025
df6c6e5
Handle unextractable terms
FTRobbin May 12, 2025
0cae876
Bad version
Alex-Fischman May 14, 2025
3b0d33b
Merge pull request #567 from Alex-Fischman/no_action_extraction
Alex-Fischman May 14, 2025
01c4f00
extraction in the front end for the new backend
FTRobbin May 21, 2025
2a80d81
Rip eval_resolved_expr_old
FTRobbin May 21, 2025
44be98e
Rename query-extract to extract in tests
FTRobbin May 21, 2025
2372f48
Minor
FTRobbin May 21, 2025
ee448d0
Fix test syntax :variant
FTRobbin May 21, 2025
a2865fc
Handle unextractable terms
FTRobbin May 22, 2025
5a0ca6b
Fix Error
FTRobbin May 22, 2025
7d6106e
Rip query-extract
FTRobbin May 22, 2025
2bcad51
Nits
FTRobbin May 22, 2025
d443765
Address Alex's comments
FTRobbin May 22, 2025
6f8f326
Addressed Alex's and Oliver's comments
FTRobbin May 22, 2025
55cf4e2
Refactor the reconstruction round
FTRobbin May 22, 2025
e85ec93
Handling term cost = any of its subterms' cost in extraction and test
FTRobbin May 22, 2025
3ddb488
Remove simplify
FTRobbin May 22, 2025
42b86ae
Tried to fix output with the new extraction
FTRobbin May 22, 2025
18f9461
fmt
FTRobbin May 22, 2025
d7d4617
Add doc to extract_best_with_sort
FTRobbin May 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
362 changes: 186 additions & 176 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ thiserror = "1"

# Backend
# TODO: move egglog-backend repo into core egglog repo
core-relations = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "e1c938d" }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "e1c938d" }
numeric-id = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "e1c938d" }
core-relations = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "4f52312" }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "4f52312" }
numeric-id = { git = "https://github.com/egraphs-good/egglog-backend.git", rev = "4f52312" }
Comment thread
Alex-Fischman marked this conversation as resolved.

# Need to add "js" feature for "graphviz-rust" to work in wasm
getrandom = { version = "0.2.10", optional = true, features = ["js"] }
Expand Down
56 changes: 1 addition & 55 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::core::{
use crate::{typechecking::FuncType, *};
use typechecking::TypeError;

use crate::{ast::Literal, core::ResolvedCall, ExtractReport, Value};
use crate::{ast::Literal, core::ResolvedCall, Value};

struct ActionCompiler<'a> {
types: &'a IndexMap<Symbol, ArcSort>,
Expand All @@ -24,11 +24,6 @@ impl ActionCompiler<'_> {
self.do_atom_term(at);
self.locals.insert(v.clone());
}
GenericCoreAction::Extract(_ann, e, b) => {
let sort = self.do_atom_term(e);
self.do_atom_term(b);
self.instructions.push(Instruction::Extract(2, sort));
}
GenericCoreAction::Set(_ann, f, args, e) => {
let ResolvedCall::Func(func) = f else {
panic!("Cannot set primitive- should have been caught by typechecking!!!")
Expand Down Expand Up @@ -130,10 +125,6 @@ enum Instruction {
Set(Symbol),
/// Union the last `n` values on the stack.
Union(usize, ArcSort),
/// Extract the best expression. `n` is always 2.
/// The first value on the stack is the expression to extract,
/// and the second value is the number of variants to extract.
Extract(usize, ArcSort),
/// Panic with the given message.
Panic(String),
}
Expand Down Expand Up @@ -330,51 +321,6 @@ impl EGraph {
});
stack.truncate(new_len);
}
Instruction::Extract(arity, sort) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
let new_len = stack.len() - arity;
let mut termdag = TermDag::default();

let variants = values[1].bits as i64;
if variants == 0 {
let (cost, term) = self.extract(values[0], &mut termdag, sort)?;
// dont turn termdag into a string if we have messages disabled for performance reasons
if self.messages_enabled() {
let extracted = termdag.to_string(&term);
log::info!("extracted with cost {cost}: {extracted}");
self.print_msg(extracted);
}
self.extract_report = Some(ExtractReport::Best {
termdag,
cost,
term,
});
} else {
if variants < 0 {
panic!("Cannot extract negative number of variants");
}
let terms =
self.extract_variants(sort, values[0], variants as usize, &mut termdag);
// Same as above, avoid turning termdag into a string if we have messages disabled for performance
if self.messages_enabled() {
log::info!("extracted variants:");
let mut msg = String::default();
msg += "(\n";
assert!(!terms.is_empty());
for expr in &terms {
let str = termdag.to_string(expr);
log::info!(" {str}");
msg += &format!(" {str}\n");
}
msg += ")";
self.print_msg(msg);
}
self.extract_report = Some(ExtractReport::Variants { termdag, terms });
}

stack.truncate(new_len);
}
Instruction::Panic(msg) => panic!("Panic: {msg}"),
Instruction::Literal(lit) => match lit {
Literal::Int(i) => stack.push(Value::from(*i)),
Expand Down
91 changes: 1 addition & 90 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,71 +140,13 @@ pub(crate) fn desugar_command(
vec![NCommand::UnstableCombinedRuleset(name, subrulesets)]
}
Command::Action(action) => vec![NCommand::CoreAction(action)],
Command::Simplify {
span,
expr,
schedule,
} => desugar_simplify(&expr, &schedule, span, parser),
Command::RunSchedule(sched) => {
vec![NCommand::RunSchedule(sched.clone())]
}
Command::PrintOverallStatistics => {
vec![NCommand::PrintOverallStatistics]
}
Command::QueryExtract {
span,
variants,
expr,
} => {
let variants = Expr::Lit(span.clone(), Literal::Int(variants.try_into().unwrap()));
if let Expr::Var(..) = expr {
// (extract {v} {variants})
vec![NCommand::CoreAction(Action::Extract(
span.clone(),
expr,
variants,
))]
} else {
// (check {expr})
// (ruleset {fresh_ruleset})
// (rule ((= {fresh} {expr}))
// ((extract {fresh} {variants}))
// :ruleset {fresh_ruleset})
// (run {fresh_ruleset} 1)
let fresh = parser.symbol_gen.fresh(&"desugar_qextract_var".into());
let fresh_ruleset = parser.symbol_gen.fresh(&"desugar_qextract_ruleset".into());
let fresh_rulename = parser.symbol_gen.fresh(&"desugar_qextract_rulename".into());
let rule = Rule {
span: span.clone(),
body: vec![Fact::Eq(
span.clone(),
Expr::Var(span.clone(), fresh),
expr.clone(),
)],
head: Actions::singleton(Action::Extract(
span.clone(),
Expr::Var(span.clone(), fresh),
variants,
)),
};
vec![
NCommand::Check(span.clone(), vec![Fact::Fact(expr.clone())]),
NCommand::AddRuleset(fresh_ruleset),
NCommand::NormRule {
name: fresh_rulename,
ruleset: fresh_ruleset,
rule,
},
NCommand::RunSchedule(Schedule::Run(
span.clone(),
RunConfig {
ruleset: fresh_ruleset,
until: None,
},
)),
]
}
}
Command::Extract(span, expr, variants) => vec![NCommand::Extract(span, expr, variants)],
Command::Check(span, facts) => vec![NCommand::Check(span, facts)],
Command::PrintFunction(span, symbol, size) => {
vec![NCommand::PrintTable(span, symbol, size)]
Expand Down Expand Up @@ -320,37 +262,6 @@ fn desugar_birewrite(ruleset: Symbol, name: Symbol, rewrite: &Rewrite) -> Vec<NC
.collect()
}

fn desugar_simplify(
expr: &Expr,
schedule: &Schedule,
span: Span,
parser: &mut Parser,
) -> Vec<NCommand> {
let mut res = vec![NCommand::Push(1)];
let lhs = parser.symbol_gen.fresh(&"desugar_simplify".into());
res.push(NCommand::CoreAction(Action::Let(
span.clone(),
lhs,
expr.clone(),
)));
res.push(NCommand::RunSchedule(schedule.clone()));
res.extend(
desugar_command(
Command::QueryExtract {
Comment thread
FTRobbin marked this conversation as resolved.
span: span.clone(),
variants: 0,
expr: Expr::Var(span.clone(), lhs),
},
parser,
false,
)
.unwrap(),
);

res.push(NCommand::Pop(span, 1));
res
}

pub fn rule_name<Head, Leaf>(command: &GenericCommand<Head, Leaf>) -> Symbol
where
Head: Clone + Display,
Expand Down
87 changes: 16 additions & 71 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ where
rule: GenericRule<Head, Leaf>,
},
CoreAction(GenericAction<Head, Leaf>),
Extract(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
RunSchedule(GenericSchedule<Head, Leaf>),
PrintOverallStatistics,
Check(Span, Vec<GenericFact<Head, Leaf>>),
Expand Down Expand Up @@ -138,6 +139,9 @@ where
GenericNCommand::RunSchedule(schedule) => GenericCommand::RunSchedule(schedule.clone()),
GenericNCommand::PrintOverallStatistics => GenericCommand::PrintOverallStatistics,
GenericNCommand::CoreAction(action) => GenericCommand::Action(action.clone()),
GenericNCommand::Extract(span, expr, variants) => {
GenericCommand::Extract(span.clone(), expr.clone(), variants.clone())
}
GenericNCommand::Check(span, facts) => {
GenericCommand::Check(span.clone(), facts.clone())
}
Expand Down Expand Up @@ -196,6 +200,9 @@ where
GenericNCommand::CoreAction(action) => {
GenericNCommand::CoreAction(action.visit_exprs(f))
}
GenericNCommand::Extract(span, expr, variants) => {
GenericNCommand::Extract(span, expr.visit_exprs(f), variants.visit_exprs(f))
}
GenericNCommand::Check(span, facts) => GenericNCommand::Check(
span,
facts.into_iter().map(|fact| fact.visit_exprs(f)).collect(),
Expand Down Expand Up @@ -560,6 +567,12 @@ where
/// (let xplusone (Add (Var "x") (Num 1)))
/// ```
Action(GenericAction<Head, Leaf>),
/// `extract` a datatype from the egraph, choosing
/// the smallest representative.
/// By default, each constructor costs 1 to extract
/// (common subexpressions are not shared in the cost
/// model).
Extract(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
/// Runs a [`Schedule`], which specifies
/// rulesets and the number of times to run them.
///
Expand All @@ -578,42 +591,6 @@ where
/// Print runtime statistics about rules
/// and rulesets so far.
PrintOverallStatistics,
// TODO provide simplify docs
Simplify {
span: Span,
expr: GenericExpr<Head, Leaf>,
schedule: GenericSchedule<Head, Leaf>,
},
/// The `query-extract` command runs a query,
/// extracting the result for each match that it finds.
/// For a simpler extraction command, use [`Action::Extract`] instead.
///
/// Example:
/// ```text
/// (query-extract (Add a b))
/// ```
///
/// Extracts every `Add` term in the database, once
/// for each class of equivalent `a` and `b`.
///
/// The resulting datatype is chosen from the egraph
/// as the smallest term by size (taking into account
/// the `:cost` annotations for each constructor).
/// This cost does *not* take into account common sub-expressions.
/// For example, the following term has cost 5:
/// ```text
/// (Add
/// (Num 1)
/// (Num 1))
/// ```
///
/// Under the hood, this command is implemented with the [`EGraph::extract`]
/// function.
QueryExtract {
span: Span,
variants: usize,
expr: GenericExpr<Head, Leaf>,
},
/// The `check` command checks that the given facts
/// match at least once in the current database.
/// The list of facts is matched in the same way a [`Command::Rule`] is matched.
Expand Down Expand Up @@ -689,6 +666,9 @@ where
variants,
} => write!(f, "(datatype {name} {})", ListDisplay(variants, " ")),
GenericCommand::Action(a) => write!(f, "{a}"),
GenericCommand::Extract(_span, expr, variants) => {
write!(f, "(extract {expr} {variants})")
}
GenericCommand::Sort(_span, name, None) => write!(f, "(sort {name})"),
GenericCommand::Sort(_span, name, Some((name2, args))) => {
write!(f, "(sort {name} ({name2} {}))", ListDisplay(args, " "))
Expand Down Expand Up @@ -745,13 +725,6 @@ where
} => rule.fmt_with_ruleset(f, *ruleset, *name),
GenericCommand::RunSchedule(sched) => write!(f, "(run-schedule {sched})"),
GenericCommand::PrintOverallStatistics => write!(f, "(print-stats)"),
GenericCommand::QueryExtract {
span: _,
variants,
expr,
} => {
write!(f, "(query-extract :variants {variants} {expr})")
}
GenericCommand::Check(_ann, facts) => {
write!(f, "(check {})", ListDisplay(facts, "\n"))
}
Expand All @@ -775,11 +748,6 @@ where
} => write!(f, "(output {file:?} {})", ListDisplay(exprs, " ")),
GenericCommand::Fail(_span, cmd) => write!(f, "(fail {cmd})"),
GenericCommand::Include(_span, file) => write!(f, "(include {file:?})"),
GenericCommand::Simplify {
span: _,
expr,
schedule,
} => write!(f, "(simplify {schedule} {expr})"),
GenericCommand::Datatypes { span: _, datatypes } => {
let datatypes: Vec<_> = datatypes
.iter()
Expand Down Expand Up @@ -1226,15 +1194,6 @@ where
/// (extract (Num 2)); Extracts Num 1
/// ```
Union(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
/// `extract` a datatype from the egraph, choosing
/// the smallest representative.
/// By default, each constructor costs 1 to extract
/// (common subexpressions are not shared in the cost
/// model).
/// The second argument is the number of variants to
/// extract, picking different terms in the
/// same equivalence class.
Extract(Span, GenericExpr<Head, Leaf>, GenericExpr<Head, Leaf>),
Panic(Span, String),
Expr(Span, GenericExpr<Head, Leaf>),
// If(Expr, Action, Action),
Expand Down Expand Up @@ -1299,9 +1258,6 @@ where
};
write!(f, "({change} ({lhs} {}))", ListDisplay(args, " "))
}
GenericAction::Extract(_ann, expr, variants) => {
write!(f, "(extract {expr} {variants})")
}
GenericAction::Panic(_ann, msg) => write!(f, "(panic {msg:?})"),
GenericAction::Expr(_ann, e) => write!(f, "{e}"),
}
Expand Down Expand Up @@ -1340,9 +1296,6 @@ where
GenericAction::Union(span, lhs, rhs) => {
GenericAction::Union(span.clone(), f(lhs), f(rhs))
}
GenericAction::Extract(span, expr, variants) => {
GenericAction::Extract(span.clone(), f(expr), f(variants))
}
GenericAction::Panic(span, msg) => GenericAction::Panic(span.clone(), msg.clone()),
GenericAction::Expr(span, e) => GenericAction::Expr(span.clone(), f(e)),
}
Expand Down Expand Up @@ -1372,9 +1325,6 @@ where
GenericAction::Union(span, lhs, rhs) => {
GenericAction::Union(span, lhs.visit_exprs(f), rhs.visit_exprs(f))
}
GenericAction::Extract(span, expr, variants) => {
GenericAction::Extract(span, expr.visit_exprs(f), variants.visit_exprs(f))
}
GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
GenericAction::Expr(span, e) => GenericAction::Expr(span, e.visit_exprs(f)),
}
Expand Down Expand Up @@ -1416,11 +1366,6 @@ where
let rhs = rhs.subst_leaf(&mut fvar_expr!());
GenericAction::Union(span, lhs, rhs)
}
GenericAction::Extract(span, expr, variants) => {
let expr = expr.subst_leaf(&mut fvar_expr!());
let variants = variants.subst_leaf(&mut fvar_expr!());
GenericAction::Extract(span, expr, variants)
}
GenericAction::Panic(span, msg) => GenericAction::Panic(span, msg.clone()),
GenericAction::Expr(span, e) => {
GenericAction::Expr(span, e.subst_leaf(&mut fvar_expr!()))
Expand Down
Loading
Loading