Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class ConstIntBoundAnalyzer {
friend class ConstraintContext;
explicit ConstIntBoundAnalyzer(AnalyzerObj* parent);
TVM_DLL ~ConstIntBoundAnalyzer();
void CopyFrom(const ConstIntBoundAnalyzer& other);
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
Expand Down Expand Up @@ -259,6 +260,7 @@ class ModularSetAnalyzer {
friend class ConstraintContext;
explicit ModularSetAnalyzer(AnalyzerObj* parent);
TVM_DLL ~ModularSetAnalyzer();
void CopyFrom(const ModularSetAnalyzer& other);
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
Expand Down Expand Up @@ -413,6 +415,7 @@ class RewriteSimplifier {
friend class CanonicalSimplifier;
explicit RewriteSimplifier(AnalyzerObj* parent);
TVM_DLL ~RewriteSimplifier();
void CopyFrom(const RewriteSimplifier& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
Expand Down Expand Up @@ -444,6 +447,7 @@ class CanonicalSimplifier {
friend class ConstraintContext;
explicit CanonicalSimplifier(AnalyzerObj* parent);
TVM_DLL ~CanonicalSimplifier();
void CopyFrom(const CanonicalSimplifier& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
Expand Down Expand Up @@ -529,6 +533,7 @@ class TransitiveComparisonAnalyzer {
friend class ConstraintContext;
TransitiveComparisonAnalyzer();
TVM_DLL ~TransitiveComparisonAnalyzer();
void CopyFrom(const TransitiveComparisonAnalyzer& other);
class Impl;
/*! \brief Internal impl */
std::unique_ptr<Impl> impl_;
Expand Down Expand Up @@ -583,6 +588,7 @@ class IntSetAnalyzer {
friend class AnalyzerObj;
explicit IntSetAnalyzer(AnalyzerObj* parent);
TVM_DLL ~IntSetAnalyzer();
void CopyFrom(const IntSetAnalyzer& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
Expand Down Expand Up @@ -747,6 +753,26 @@ class TVM_DLL AnalyzerObj : public ffi::Object {
*/
PrimExpr Simplify(const PrimExpr& expr, int steps = 2);

/*!
* \brief Deep-copy this analyzer into a new, independent Analyzer.
*
* The returned analyzer carries the same accumulated facts (variable
* bounds, modular sets, rewrite/canonical bindings, integer-set domains,
* literal constraints and transitive comparisons) as this one, but owns
* its own state: binding or simplifying on either analyzer afterwards does
* not affect the other. This is the deep copy that handle-copying an
* Analyzer does not provide.
*
* \note Do not call this while a `With<ConstraintContext>` scope is active
* on this analyzer. The clone would inherit the scoped constraints
* but not the recovery functions that pop them on scope exit, so the
* constraints would leak as if they were global facts. Clone at a
* point where no constraint scope is in effect.
*
* \return A new Analyzer holding an independent copy of the facts.
*/
Analyzer Clone() const;

/*!
* \brief Analyzer methods update facts, constraints, caches, and stats.
*
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,24 @@ def simplify(self, expr: tirx.PrimExpr, steps: int = 2) -> tirx.PrimExpr:
"""
return _ffi_api.AnalyzerSimplify(self, expr, steps)

def clone(self) -> "Analyzer":
"""Return a deep copy of this analyzer with independent state.

The returned analyzer carries the same accumulated facts (variable
bounds, modular sets, bindings, integer-set domains, literal
constraints and transitive comparisons) as this one, but owns its own
state: binding or simplifying on either analyzer afterwards does not
affect the other. Unlike copying the handle, this is a true deep copy.

Do not call this while a constraint scope is active on this analyzer.

Returns
-------
result : Analyzer
A new analyzer holding an independent copy of the facts.
"""
return _ffi_api.AnalyzerClone(self)

def rewrite_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr:
"""Simplify expression via rewriting rules.

Expand Down
12 changes: 12 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,23 @@ PrimExpr AnalyzerObj::Simplify(const PrimExpr& expr, int steps) {
return res;
}

Analyzer AnalyzerObj::Clone() const {
Analyzer cloned;
cloned->const_int_bound.CopyFrom(this->const_int_bound);
cloned->modular_set.CopyFrom(this->modular_set);
cloned->rewrite_simplify.CopyFrom(this->rewrite_simplify);
cloned->canonical_simplify.CopyFrom(this->canonical_simplify);
cloned->int_set.CopyFrom(this->int_set);
cloned->transitive_comparisons.CopyFrom(this->transitive_comparisons);
return cloned;
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AnalyzerObj>();
refl::GlobalDef()
.def("arith.Analyzer", []() { return Analyzer(); })
.def("arith.AnalyzerClone", [](Analyzer analyzer) { return analyzer->Clone(); })
.def("arith.AnalyzerConstIntBound",
[](Analyzer analyzer, const PrimExpr& expr) { return analyzer->const_int_bound(expr); })
.def("arith.AnalyzerConstIntBoundUpdate",
Expand Down
4 changes: 4 additions & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1454,5 +1454,9 @@ CanonicalSimplifier::CanonicalSimplifier(AnalyzerObj* parent) : impl_(new Impl(p

CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; }

void CanonicalSimplifier::CopyFrom(const CanonicalSimplifier& other) {
impl_->CopyFrom(*other.impl_);
}

} // namespace arith
} // namespace tvm
9 changes: 9 additions & 0 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,11 @@ class ConstIntBoundAnalyzer::Impl
return frecover;
}

void CopyFrom(const Impl& other) {
var_map_ = other.var_map_;
additional_info_ = other.additional_info_;
}

private:
friend class ConstIntBoundAnalyzer;
// parent analyzer
Expand Down Expand Up @@ -859,5 +864,9 @@ ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(AnalyzerObj* parent) : impl_(new Im

ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; }

void ConstIntBoundAnalyzer::CopyFrom(const ConstIntBoundAnalyzer& other) {
impl_->CopyFrom(*other.impl_);
}

} // namespace arith
} // namespace tvm
7 changes: 7 additions & 0 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,11 @@ class IntSetAnalyzer::Impl {
void Bind(const Var& var, const PrimExpr& expr, bool override_info);
std::function<void()> EnterConstraint(const PrimExpr& constraint);

void CopyFrom(const Impl& other) {
dom_map_ = other.dom_map_;
dom_constraints_ = other.dom_constraints_;
}

private:
// Utility function to split a boolean condition into the domain
// bounds implied by that condition.
Expand All @@ -681,6 +686,8 @@ IntSetAnalyzer::IntSetAnalyzer(AnalyzerObj* parent) : impl_(new Impl(parent)) {}

IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; }

void IntSetAnalyzer::CopyFrom(const IntSetAnalyzer& other) { impl_->CopyFrom(*other.impl_); }

IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const ffi::Map<Var, IntSet>& dom_map) {
return impl_->Eval(expr, dom_map);
}
Expand Down
6 changes: 6 additions & 0 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
return Everything();
}

void CopyFrom(const Impl& other) { var_map_ = other.var_map_; }

private:
/*! \brief pointer to parent. */
AnalyzerObj* parent_{nullptr};
Expand Down Expand Up @@ -407,5 +409,9 @@ ModularSetAnalyzer::ModularSetAnalyzer(AnalyzerObj* parent) : impl_(new Impl(par

ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; }

void ModularSetAnalyzer::CopyFrom(const ModularSetAnalyzer& other) {
impl_->CopyFrom(*other.impl_);
}

} // namespace arith
} // namespace tvm
2 changes: 2 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2466,6 +2466,8 @@ RewriteSimplifier::RewriteSimplifier(AnalyzerObj* parent) : impl_(new Impl(paren

RewriteSimplifier::~RewriteSimplifier() { delete impl_; }

void RewriteSimplifier::CopyFrom(const RewriteSimplifier& other) { impl_->CopyFrom(*other.impl_); }

// Pattern A (RM): auto-default repr from reflection.

} // namespace arith
Expand Down
7 changes: 7 additions & 0 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {

void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = maximum; }

void CopyFrom(const Impl& other) {
var_map_ = other.var_map_;
literal_constraints_ = other.literal_constraints_;
enabled_extensions_ = other.enabled_extensions_;
maximum_rewrite_steps_ = other.maximum_rewrite_steps_;
}

protected:
int64_t maximum_rewrite_steps_{0};
RewriteSimplifierStatsNode stats_;
Expand Down
11 changes: 11 additions & 0 deletions src/arith/transitive_comparison_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ class TransitiveComparisonAnalyzer::Impl {
*/
std::function<void()> EnterConstraint(const PrimExpr& expr);

void CopyFrom(const Impl& other) {
expr_to_key = other.expr_to_key;
prev_bindings_ = other.prev_bindings_;
knowns_ = other.knowns_;
scoped_knowns_ = other.scoped_knowns_;
}

private:
/* \brief Internal representation of a PrimExpr
*
Expand Down Expand Up @@ -528,6 +535,10 @@ bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}

void TransitiveComparisonAnalyzer::CopyFrom(const TransitiveComparisonAnalyzer& other) {
impl_->CopyFrom(*other.impl_);
}

CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
bool propagate_inequalities) {
return impl_->TryCompare(lhs, rhs, propagate_inequalities);
Expand Down
21 changes: 21 additions & 0 deletions tests/cpp/arith_simplify_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ TEST(AnalyzerObjectRef, ConstHandleRefCanMutateAnalyzerState) {
TVM_FFI_ICHECK(analyzer->CanProve(x < 8));
}

TEST(AnalyzerObjectRef, CloneIsIndependent) {
tvm::arith::Analyzer analyzer;
auto x = tvm::te::var("x");
auto y = tvm::te::var("y");

analyzer->Bind(x, tvm::Range::FromMinExtent(0, 8));
analyzer->modular_set.Update(x, tvm::arith::ModularSet(4, 0));

tvm::arith::Analyzer clone = analyzer->Clone();
TVM_FFI_ICHECK(clone->CanProve(x < 8));
TVM_FFI_ICHECK(clone->modular_set(x)->coeff == 4);

clone->Bind(y, tvm::Range::FromMinExtent(0, 4));
clone->modular_set.Update(x, tvm::arith::ModularSet(8, 0), true);
TVM_FFI_ICHECK(clone->CanProve(y < 4));
TVM_FFI_ICHECK(!analyzer->CanProve(y < 4));
TVM_FFI_ICHECK(analyzer->CanProve(x < 8));
TVM_FFI_ICHECK(analyzer->modular_set(x)->coeff == 4);
TVM_FFI_ICHECK(clone->modular_set(x)->coeff == 8);
}

TEST(ConstantFold, Broadcast) {
tvm::ffi::StructuralEqual checker;
auto i32x4 = tvm::tirx::Broadcast(tvm::IntImm::Int32(10), 4);
Expand Down
79 changes: 79 additions & 0 deletions tests/python/arith/test_arith_analyzer_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,5 +204,84 @@ def test_analyzer_object_state_persists_across_ffi_calls():
tvm.ir.assert_structural_equal(analyzer.simplify(tile), tvm.tirx.const(8, "int32"))


def test_analyzer_object_clone_is_independent():
analyzer = tvm.arith.Analyzer()
x = tirx.Var("x", "int64")
y = tirx.Var("y", "int64")
z = tirx.Var("z", "int64")

analyzer.bind(x, tvm.ir.Range(0, 8))

clone = analyzer.clone()
assert clone is not analyzer
assert clone.can_prove(x < 8)

clone.bind(y, tvm.ir.Range(0, 4))
assert clone.can_prove(y < 4)
assert not analyzer.can_prove(y < 4)

analyzer.bind(z, tvm.ir.Range(0, 4))
assert analyzer.can_prove(z < 4)
assert not clone.can_prove(z < 4)

assert analyzer.can_prove(x < 8)
assert clone.can_prove(x < 8)


def test_analyzer_object_clone_copies_every_sub_analyzer():
analyzer = tvm.arith.Analyzer()
x = tirx.Var("x", "int64")
w = tirx.Var("w", "int64")
v = tirx.Var("v", "int64")

analyzer.bind(x, tvm.ir.Range(0, 8))
analyzer.update(x, tvm.arith.ModularSet(4, 0))
analyzer.bind(w, tirx.const(4, "int64"))
analyzer.update(v, tvm.arith.IntervalSet(2, 9))
analyzer.enabled_extensions = Extension.ComparisonOfProductAndSum

clone = analyzer.clone()

assert clone.can_prove(x < 8)
assert clone.modular_set(x).coeff == 4
tvm.ir.assert_structural_equal(clone.simplify(w + 1), tirx.const(5, "int64"))
assert clone.int_set(v).max_value.value == 9
assert clone.enabled_extensions == Extension.ComparisonOfProductAndSum
assert clone.try_compare(x, tirx.const(0, "int64")) == CompareResult.GE

t = tirx.Var("t", "int64")
clone.update(x, tvm.arith.ModularSet(8, 0), override=True)
clone.update(v, tvm.arith.IntervalSet(0, 3), override=True)
clone.bind(w, tirx.const(8, "int64"), allow_override=True)
clone.bind(t, tvm.ir.Range(0, 4))
clone.enabled_extensions = Extension.NoExtensions

assert analyzer.modular_set(x).coeff == 4
assert clone.modular_set(x).coeff == 8
assert analyzer.int_set(v).max_value.value == 9
assert clone.int_set(v).max_value.value == 3
tvm.ir.assert_structural_equal(analyzer.simplify(w + 1), tirx.const(5, "int64"))
tvm.ir.assert_structural_equal(clone.simplify(w + 1), tirx.const(9, "int64"))
assert analyzer.enabled_extensions == Extension.ComparisonOfProductAndSum
assert clone.enabled_extensions == Extension.NoExtensions
assert clone.try_compare(t, tirx.const(0, "int64")) == CompareResult.GE
assert analyzer.try_compare(t, tirx.const(0, "int64")) == CompareResult.UNKNOWN


def test_analyzer_object_clone_resets_rewrite_stats():
analyzer = tvm.arith.Analyzer()
x = tirx.Var("x", "int64")
y = tirx.Var("y", "int64")
analyzer.bind(x, tvm.ir.Range(0, 8))
analyzer.bind(y, tvm.ir.Range(0, 8))
analyzer.simplify((x + y) * 2 - x - y)
source_attempts = analyzer.rewrite_simplify_stats.rewrites_attempted
assert source_attempts > 0

clone = analyzer.clone()
assert clone.rewrite_simplify_stats.rewrites_attempted == 0
assert analyzer.rewrite_simplify_stats.rewrites_attempted == source_attempts


if __name__ == "__main__":
tvm.testing.main()
Loading