Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,28 +334,24 @@ impl EGraph {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
let new_len = stack.len() - arity;
let mut termdag = TermDag::default();

let variants = values[1].bits as i64;
if variants == 0 {
let (cost, term) = self.extract(values[0], &mut termdag, sort)?;
let (cost, term) = self.extract(values[0], sort)?;
let termdag = &self.termdag;
// dont turn termdag into a string if we have messages disabled for performance reasons
if self.messages_enabled() {
let extracted = termdag.to_string(&term);
log::info!("extracted with cost {cost}: {extracted}");
self.print_msg(extracted);
}
self.extract_report = Some(ExtractReport::Best {
termdag,
cost,
term,
});
self.extract_report = Some(ExtractReport::Best { cost, term });
} else {
if variants < 0 {
panic!("Cannot extract negative number of variants");
}
let terms =
self.extract_variants(sort, values[0], variants as usize, &mut termdag);
let terms = self.extract_variants(sort, values[0], variants as usize);
let termdag = &self.termdag;
// Same as above, avoid turning termdag into a string if we have messages disabled for performance
if self.messages_enabled() {
log::info!("extracted variants:");
Expand All @@ -370,7 +366,7 @@ impl EGraph {
msg += ")";
self.print_msg(msg);
}
self.extract_report = Some(ExtractReport::Variants { termdag, terms });
self.extract_report = Some(ExtractReport::Variants { terms });
}

stack.truncate(new_len);
Expand Down
170 changes: 104 additions & 66 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::util::HashMap;
use crate::{ArcSort, EGraph, Error, Function, HEntry, Id, Value};

pub type Cost = usize;
pub(crate) type CostMap = HashMap<Id, (Cost, Term)>;

#[derive(Debug)]
pub(crate) struct Node<'a> {
Expand All @@ -13,7 +14,7 @@ pub(crate) struct Node<'a> {
}

pub struct Extractor<'a> {
pub costs: HashMap<Id, (Cost, Term)>,
pub costs: CostMap,
ctors: Vec<Symbol>,
egraph: &'a EGraph,
}
Expand All @@ -37,78 +38,67 @@ impl EGraph {
/// let (_, extracted) = egraph.extract(value, &mut termdag, &sort).unwrap();
/// assert_eq!(termdag.to_string(&extracted), "(Add 1 1)");
/// ```
pub fn extract(
&self,
value: Value,
termdag: &mut TermDag,
arcsort: &ArcSort,
) -> Result<(Cost, Term), Error> {
let extractor = Extractor::new(self, termdag);
extractor.find_best(value, termdag, arcsort).ok_or_else(|| {
log::error!("No cost for {:?}", value);
for func in self.functions.values() {
for (inputs, output) in func.nodes.iter(false) {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);

assert_eq!(inputs.len(), func.schema.input.len());
log::error!(
"{:?}",
inputs
.iter()
.zip(&func.schema.input)
.map(|(input, sort)| extractor
.costs
.get(&extractor.egraph.find(sort, *input).bits))
.collect::<Vec<_>>()
);
pub fn extract(&mut self, value: Value, arcsort: &ArcSort) -> Result<(Cost, Term), Error> {
let mut cost_map = None;
let mut termdag = std::mem::take(&mut self.termdag);
if self.cost_cache.is_some() {
let cost_map_ts = std::mem::take(&mut self.cost_cache).unwrap();
if cost_map_ts.0 == self.timestamp {
cost_map = Some(cost_map_ts.1);
}
}
let extractor = Extractor::new(self, &mut termdag, cost_map);
let result = extractor
.find_best(value, &mut termdag, arcsort)
.ok_or_else(|| {
log::error!("No cost for {:?}", value);
for func in self.functions.values() {
for (inputs, output) in func.nodes.iter(false) {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);

assert_eq!(inputs.len(), func.schema.input.len());
log::error!(
"{:?}",
inputs
.iter()
.zip(&func.schema.input)
.map(|(input, sort)| extractor
.costs
.get(&extractor.egraph.find(sort, *input).bits))
.collect::<Vec<_>>()
);
}
}
}
}
Error::ExtractError(value)
})
Error::ExtractError(value)
});
self.cost_cache = Some((self.timestamp, extractor.cost_map()));
self.termdag = termdag;
result
}

pub fn extract_variants(
&mut self,
sort: &ArcSort,
value: Value,
limit: usize,
termdag: &mut TermDag,
) -> Vec<Term> {
let output_sort = sort.name();
let output_value = self.find(sort, value);
let ext = &Extractor::new(self, termdag);
ext.ctors
.iter()
.flat_map(|&sym| {
let func = &self.functions[&sym];
if !func.schema.output.is_eq_sort() {
return vec![];
}
assert!(func.schema.output.is_eq_sort());

func.nodes
.iter(false)
.filter(|&(_, output)| {
func.schema.output.name() == output_sort && output.value == output_value
})
.map(|(inputs, _output)| {
let node = Node { sym, func, inputs };
ext.expr_from_node(&node, termdag).expect(
"extract_variants should be called after extractor initialization",
)
})
.collect()
})
.take(limit)
.collect()
/// Extracts up to `limit` terms for a given `value`.
pub fn extract_variants(&mut self, sort: &ArcSort, value: Value, limit: usize) -> Vec<Term> {
let mut cost_map = None;
let mut termdag = std::mem::take(&mut self.termdag);
if self.cost_cache.is_some() {
let cost_map_ts = std::mem::take(&mut self.cost_cache).unwrap();
if cost_map_ts.0 == self.timestamp {
cost_map = Some(cost_map_ts.1);
}
}
let extractor = Extractor::new(self, &mut termdag, cost_map);
let result = extractor.find_variants(value, &mut termdag, sort, limit);
self.cost_cache = Some((self.timestamp, extractor.cost_map()));
self.termdag = termdag;
result
}
}

impl<'a> Extractor<'a> {
pub fn new(egraph: &'a EGraph, termdag: &mut TermDag) -> Self {
pub fn new(egraph: &'a EGraph, termdag: &mut TermDag, cost_map: Option<CostMap>) -> Self {
let mut extractor = Extractor {
costs: HashMap::default(),
egraph,
Expand All @@ -125,7 +115,11 @@ impl<'a> Extractor<'a> {
);

log::debug!("Extracting from ctors: {:?}", extractor.ctors);
extractor.find_costs(termdag);
if let Some(cost_map) = cost_map {
extractor.costs = cost_map;
} else {
extractor.find_costs(termdag);
}
extractor
}

Expand Down Expand Up @@ -159,6 +153,46 @@ impl<'a> Extractor<'a> {
}
}

/// Extracts up to `limit` terms for a given `value`.
pub fn find_variants(
&self,
value: Value,
termdag: &mut TermDag,
sort: &ArcSort,
limit: usize,
) -> Vec<Term> {
let output_sort = sort.name();
let output_value = self.egraph.find(sort, value);
let terms = self
.ctors
.iter()
.flat_map(|&sym| {
let func = &self.egraph.functions[&sym];
if !func.schema.output.is_eq_sort() {
return vec![];
}
assert!(func.schema.output.is_eq_sort());

func.nodes
.iter(false)
.filter(|&(_, output)| {
func.schema.output.name() == output_sort && output.value == output_value
})
.map(|(inputs, _output)| {
let node = Node { sym, func, inputs };
self.expr_from_node(&node, termdag).expect(
"extract_variants should be called after extractor initialization",
)
})
.collect()
})
.take(limit)
.collect::<Vec<Term>>();

// TODO: what happens if `terms` is empty?
terms
}

fn node_total_cost(
&mut self,
function: &Function,
Expand Down Expand Up @@ -209,4 +243,8 @@ impl<'a> Extractor<'a> {
}
}
}

pub fn cost_map(self) -> CostMap {
self.costs
}
}
60 changes: 40 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use ast::*;
#[cfg(feature = "bin")]
pub use cli::bin::*;
use constraint::{Constraint, SimpleTypeConstraint, TypeConstraint};
use extract::Extractor;
use extract::{CostMap, Extractor};
pub use function::Function;
use function::*;
use gj::*;
Expand Down Expand Up @@ -223,15 +223,8 @@ impl Display for RunReport {
/// A report of the results of an extract action.
#[derive(Debug, Clone)]
pub enum ExtractReport {
Best {
termdag: TermDag,
cost: usize,
term: Term,
},
Variants {
termdag: TermDag,
terms: Vec<Term>,
},
Best { cost: usize, term: Term },
Variants { terms: Vec<Term> },
}

impl RunReport {
Expand Down Expand Up @@ -449,6 +442,8 @@ pub struct EGraph {
overall_run_report: RunReport,
/// Messages to be printed to the user. If this is `None`, then we are ignoring messages.
msgs: Option<Vec<String>>,
cost_cache: Option<(u32, CostMap)>,
termdag: TermDag,
}

impl Default for EGraph {
Expand All @@ -471,6 +466,8 @@ impl Default for EGraph {
overall_run_report: Default::default(),
msgs: Some(vec![]),
type_info: Default::default(),
cost_cache: None,
termdag: Default::default(),
};
egraph
.rulesets
Expand Down Expand Up @@ -730,8 +727,15 @@ impl EGraph {
.map(|(k, v)| (ValueVec::from(k), v.clone()))
.collect::<Vec<_>>();

let mut cost_map = None;
if self.cost_cache.is_some() {
let cost_map_ts = std::mem::take(&mut self.cost_cache).unwrap();
if cost_map_ts.0 == self.timestamp {
cost_map = Some(cost_map_ts.1);
}
}
let mut termdag = TermDag::default();
let extractor = Extractor::new(self, &mut termdag);
let extractor = Extractor::new(self, &mut termdag, cost_map);
let mut terms = Vec::new();
for (ins, out) in nodes {
let mut children = Vec::new();
Expand Down Expand Up @@ -762,7 +766,10 @@ impl EGraph {
};
terms.push((termdag.app(sym, children), out));
}
drop(extractor);

self.cost_cache = Some((self.timestamp, extractor.cost_map()));
// TODO: this clone is bad but this is just to avoid introducing breaking changes to termdag
self.termdag = termdag.clone();

Ok((terms, termdag))
}
Expand Down Expand Up @@ -876,15 +883,22 @@ impl EGraph {
/// Extract a value to a [`TermDag`] and [`Term`] in the [`TermDag`].
/// Note that the `TermDag` may contain a superset of the nodes in the `Term`.
/// See also `extract_value_to_string` for convenience.
pub fn extract_value(&self, sort: &ArcSort, value: Value) -> Result<(TermDag, Term), Error> {
let mut termdag = TermDag::default();
let term = self.extract(value, &mut termdag, sort)?.1;
Ok((termdag, term))
pub fn extract_value(
&mut self,
sort: &ArcSort,
value: Value,
) -> Result<(&mut TermDag, Term), Error> {
let term = self.extract(value, sort)?.1;
Ok((&mut self.termdag, term))
}

/// Extract a value to a string for printing.
/// See also `extract_value` for more control.
pub fn extract_value_to_string(&self, sort: &ArcSort, value: Value) -> Result<String, Error> {
pub fn extract_value_to_string(
&mut self,
sort: &ArcSort,
value: Value,
) -> Result<String, Error> {
let (termdag, term) = self.extract_value(sort, value)?;
Ok(termdag.to_string(&term))
}
Expand Down Expand Up @@ -1313,13 +1327,12 @@ impl EGraph {
.create(true)
.open(&filename)
.map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
let mut termdag = TermDag::default();
for expr in exprs {
let value = self.eval_resolved_expr(&expr)?;
let expr_type = expr.output_type();
let term = self.extract(value, &mut termdag, &expr_type)?.1;
let term = self.extract(value, &expr_type)?.1;
use std::io::Write;
writeln!(f, "{}", termdag.to_string(&term))
writeln!(f, "{}", self.termdag().to_string(&term))
.map_err(|e| Error::IoError(filename.clone(), e, span.clone()))?;
}

Expand Down Expand Up @@ -1549,6 +1562,13 @@ impl EGraph {
vec![]
}
}

/// Gets a reference to the internal TermDag structure EGraph keeps for extracted programs.
///
/// The TermDag is refreshed when (1) a new `extract` is performed AND (2) the internal timestamp of the EGraph is increased.
pub fn termdag(&self) -> &TermDag {
&self.termdag
}
}

// Currently, only the following errors can thrown without location information:
Expand Down
Loading
Loading