diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 924cc299270a..8a6f4381c69a 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -180,6 +180,7 @@ class ConstIntBoundAnalyzer { friend class ConstraintContext; explicit ConstIntBoundAnalyzer(AnalyzerObj* parent); TVM_DLL ~ConstIntBoundAnalyzer(); + void CopyFrom(const ConstIntBoundAnalyzer& other); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -259,6 +260,7 @@ class ModularSetAnalyzer { friend class ConstraintContext; explicit ModularSetAnalyzer(AnalyzerObj* parent); TVM_DLL ~ModularSetAnalyzer(); + void CopyFrom(const ModularSetAnalyzer& other); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -413,6 +415,7 @@ class RewriteSimplifier { friend class CanonicalSimplifier; explicit RewriteSimplifier(AnalyzerObj* parent); TVM_DLL ~RewriteSimplifier(); + void CopyFrom(const RewriteSimplifier& other); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -444,6 +447,7 @@ class CanonicalSimplifier { friend class ConstraintContext; explicit CanonicalSimplifier(AnalyzerObj* parent); TVM_DLL ~CanonicalSimplifier(); + void CopyFrom(const CanonicalSimplifier& other); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -529,6 +533,7 @@ class TransitiveComparisonAnalyzer { friend class ConstraintContext; TransitiveComparisonAnalyzer(); TVM_DLL ~TransitiveComparisonAnalyzer(); + void CopyFrom(const TransitiveComparisonAnalyzer& other); class Impl; /*! \brief Internal impl */ std::unique_ptr impl_; @@ -583,6 +588,7 @@ class IntSetAnalyzer { friend class AnalyzerObj; explicit IntSetAnalyzer(AnalyzerObj* parent); TVM_DLL ~IntSetAnalyzer(); + void CopyFrom(const IntSetAnalyzer& other); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -747,6 +753,26 @@ class TVM_DLL AnalyzerObj : public ffi::Object { */ PrimExpr Simplify(const PrimExpr& expr, int steps = 2); + /*! + * \brief Deep-copy this analyzer into a new, independent Analyzer. + * + * The returned analyzer carries the same accumulated facts (variable + * bounds, modular sets, rewrite/canonical bindings, integer-set domains, + * literal constraints and transitive comparisons) as this one, but owns + * its own state: binding or simplifying on either analyzer afterwards does + * not affect the other. This is the deep copy that handle-copying an + * Analyzer does not provide. + * + * \note Do not call this while a `With` scope is active + * on this analyzer. The clone would inherit the scoped constraints + * but not the recovery functions that pop them on scope exit, so the + * constraints would leak as if they were global facts. Clone at a + * point where no constraint scope is in effect. + * + * \return A new Analyzer holding an independent copy of the facts. + */ + Analyzer Clone() const; + /*! * \brief Analyzer methods update facts, constraints, caches, and stats. * diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 0aa6a75eba4a..b2ca7ad4a3bb 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -193,6 +193,24 @@ def simplify(self, expr: tirx.PrimExpr, steps: int = 2) -> tirx.PrimExpr: """ return _ffi_api.AnalyzerSimplify(self, expr, steps) + def clone(self) -> "Analyzer": + """Return a deep copy of this analyzer with independent state. + + The returned analyzer carries the same accumulated facts (variable + bounds, modular sets, bindings, integer-set domains, literal + constraints and transitive comparisons) as this one, but owns its own + state: binding or simplifying on either analyzer afterwards does not + affect the other. Unlike copying the handle, this is a true deep copy. + + Do not call this while a constraint scope is active on this analyzer. + + Returns + ------- + result : Analyzer + A new analyzer holding an independent copy of the facts. + """ + return _ffi_api.AnalyzerClone(self) + def rewrite_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: """Simplify expression via rewriting rules. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 69dbe97f5e8a..0937c8e5acc1 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -255,11 +255,23 @@ PrimExpr AnalyzerObj::Simplify(const PrimExpr& expr, int steps) { return res; } +Analyzer AnalyzerObj::Clone() const { + Analyzer cloned; + cloned->const_int_bound.CopyFrom(this->const_int_bound); + cloned->modular_set.CopyFrom(this->modular_set); + cloned->rewrite_simplify.CopyFrom(this->rewrite_simplify); + cloned->canonical_simplify.CopyFrom(this->canonical_simplify); + cloned->int_set.CopyFrom(this->int_set); + cloned->transitive_comparisons.CopyFrom(this->transitive_comparisons); + return cloned; +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); refl::GlobalDef() .def("arith.Analyzer", []() { return Analyzer(); }) + .def("arith.AnalyzerClone", [](Analyzer analyzer) { return analyzer->Clone(); }) .def("arith.AnalyzerConstIntBound", [](Analyzer analyzer, const PrimExpr& expr) { return analyzer->const_int_bound(expr); }) .def("arith.AnalyzerConstIntBoundUpdate", diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index f1dd1a63c5e7..7806c234450a 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1454,5 +1454,9 @@ CanonicalSimplifier::CanonicalSimplifier(AnalyzerObj* parent) : impl_(new Impl(p CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } +void CanonicalSimplifier::CopyFrom(const CanonicalSimplifier& other) { + impl_->CopyFrom(*other.impl_); +} + } // namespace arith } // namespace tvm diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 8ff1a8b17ecf..4d700564ea05 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -498,6 +498,11 @@ class ConstIntBoundAnalyzer::Impl return frecover; } + void CopyFrom(const Impl& other) { + var_map_ = other.var_map_; + additional_info_ = other.additional_info_; + } + private: friend class ConstIntBoundAnalyzer; // parent analyzer @@ -859,5 +864,9 @@ ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(AnalyzerObj* parent) : impl_(new Im ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } +void ConstIntBoundAnalyzer::CopyFrom(const ConstIntBoundAnalyzer& other) { + impl_->CopyFrom(*other.impl_); +} + } // namespace arith } // namespace tvm diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index a1e01d3e86a0..b68042e2afef 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -658,6 +658,11 @@ class IntSetAnalyzer::Impl { void Bind(const Var& var, const PrimExpr& expr, bool override_info); std::function EnterConstraint(const PrimExpr& constraint); + void CopyFrom(const Impl& other) { + dom_map_ = other.dom_map_; + dom_constraints_ = other.dom_constraints_; + } + private: // Utility function to split a boolean condition into the domain // bounds implied by that condition. @@ -681,6 +686,8 @@ IntSetAnalyzer::IntSetAnalyzer(AnalyzerObj* parent) : impl_(new Impl(parent)) {} IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } +void IntSetAnalyzer::CopyFrom(const IntSetAnalyzer& other) { impl_->CopyFrom(*other.impl_); } + IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const ffi::Map& dom_map) { return impl_->Eval(expr, dom_map); } diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 5f66356e1ae9..856f5df0b7f9 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -310,6 +310,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctorCopyFrom(*other.impl_); +} + } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index b5b0cc604e22..a824a2553f00 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -2466,6 +2466,8 @@ RewriteSimplifier::RewriteSimplifier(AnalyzerObj* parent) : impl_(new Impl(paren RewriteSimplifier::~RewriteSimplifier() { delete impl_; } +void RewriteSimplifier::CopyFrom(const RewriteSimplifier& other) { impl_->CopyFrom(*other.impl_); } + // Pattern A (RM): auto-default repr from reflection. } // namespace arith diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index b42b73336a27..719aa5ec0701 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -135,6 +135,13 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = maximum; } + void CopyFrom(const Impl& other) { + var_map_ = other.var_map_; + literal_constraints_ = other.literal_constraints_; + enabled_extensions_ = other.enabled_extensions_; + maximum_rewrite_steps_ = other.maximum_rewrite_steps_; + } + protected: int64_t maximum_rewrite_steps_{0}; RewriteSimplifierStatsNode stats_; diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index e7deea4cfd56..20fd05169f43 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -82,6 +82,13 @@ class TransitiveComparisonAnalyzer::Impl { */ std::function EnterConstraint(const PrimExpr& expr); + void CopyFrom(const Impl& other) { + expr_to_key = other.expr_to_key; + prev_bindings_ = other.prev_bindings_; + knowns_ = other.knowns_; + scoped_knowns_ = other.scoped_knowns_; + } + private: /* \brief Internal representation of a PrimExpr * @@ -528,6 +535,10 @@ bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique()) {} TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} +void TransitiveComparisonAnalyzer::CopyFrom(const TransitiveComparisonAnalyzer& other) { + impl_->CopyFrom(*other.impl_); +} + CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs, bool propagate_inequalities) { return impl_->TryCompare(lhs, rhs, propagate_inequalities); diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index ba5305e9dd1b..d5050446d6a5 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -75,6 +75,27 @@ TEST(AnalyzerObjectRef, ConstHandleRefCanMutateAnalyzerState) { TVM_FFI_ICHECK(analyzer->CanProve(x < 8)); } +TEST(AnalyzerObjectRef, CloneIsIndependent) { + tvm::arith::Analyzer analyzer; + auto x = tvm::te::var("x"); + auto y = tvm::te::var("y"); + + analyzer->Bind(x, tvm::Range::FromMinExtent(0, 8)); + analyzer->modular_set.Update(x, tvm::arith::ModularSet(4, 0)); + + tvm::arith::Analyzer clone = analyzer->Clone(); + TVM_FFI_ICHECK(clone->CanProve(x < 8)); + TVM_FFI_ICHECK(clone->modular_set(x)->coeff == 4); + + clone->Bind(y, tvm::Range::FromMinExtent(0, 4)); + clone->modular_set.Update(x, tvm::arith::ModularSet(8, 0), true); + TVM_FFI_ICHECK(clone->CanProve(y < 4)); + TVM_FFI_ICHECK(!analyzer->CanProve(y < 4)); + TVM_FFI_ICHECK(analyzer->CanProve(x < 8)); + TVM_FFI_ICHECK(analyzer->modular_set(x)->coeff == 4); + TVM_FFI_ICHECK(clone->modular_set(x)->coeff == 8); +} + TEST(ConstantFold, Broadcast) { tvm::ffi::StructuralEqual checker; auto i32x4 = tvm::tirx::Broadcast(tvm::IntImm::Int32(10), 4); diff --git a/tests/python/arith/test_arith_analyzer_object.py b/tests/python/arith/test_arith_analyzer_object.py index 4b4c4134b9be..9edd75d7aa7b 100644 --- a/tests/python/arith/test_arith_analyzer_object.py +++ b/tests/python/arith/test_arith_analyzer_object.py @@ -204,5 +204,84 @@ def test_analyzer_object_state_persists_across_ffi_calls(): tvm.ir.assert_structural_equal(analyzer.simplify(tile), tvm.tirx.const(8, "int32")) +def test_analyzer_object_clone_is_independent(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int64") + y = tirx.Var("y", "int64") + z = tirx.Var("z", "int64") + + analyzer.bind(x, tvm.ir.Range(0, 8)) + + clone = analyzer.clone() + assert clone is not analyzer + assert clone.can_prove(x < 8) + + clone.bind(y, tvm.ir.Range(0, 4)) + assert clone.can_prove(y < 4) + assert not analyzer.can_prove(y < 4) + + analyzer.bind(z, tvm.ir.Range(0, 4)) + assert analyzer.can_prove(z < 4) + assert not clone.can_prove(z < 4) + + assert analyzer.can_prove(x < 8) + assert clone.can_prove(x < 8) + + +def test_analyzer_object_clone_copies_every_sub_analyzer(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int64") + w = tirx.Var("w", "int64") + v = tirx.Var("v", "int64") + + analyzer.bind(x, tvm.ir.Range(0, 8)) + analyzer.update(x, tvm.arith.ModularSet(4, 0)) + analyzer.bind(w, tirx.const(4, "int64")) + analyzer.update(v, tvm.arith.IntervalSet(2, 9)) + analyzer.enabled_extensions = Extension.ComparisonOfProductAndSum + + clone = analyzer.clone() + + assert clone.can_prove(x < 8) + assert clone.modular_set(x).coeff == 4 + tvm.ir.assert_structural_equal(clone.simplify(w + 1), tirx.const(5, "int64")) + assert clone.int_set(v).max_value.value == 9 + assert clone.enabled_extensions == Extension.ComparisonOfProductAndSum + assert clone.try_compare(x, tirx.const(0, "int64")) == CompareResult.GE + + t = tirx.Var("t", "int64") + clone.update(x, tvm.arith.ModularSet(8, 0), override=True) + clone.update(v, tvm.arith.IntervalSet(0, 3), override=True) + clone.bind(w, tirx.const(8, "int64"), allow_override=True) + clone.bind(t, tvm.ir.Range(0, 4)) + clone.enabled_extensions = Extension.NoExtensions + + assert analyzer.modular_set(x).coeff == 4 + assert clone.modular_set(x).coeff == 8 + assert analyzer.int_set(v).max_value.value == 9 + assert clone.int_set(v).max_value.value == 3 + tvm.ir.assert_structural_equal(analyzer.simplify(w + 1), tirx.const(5, "int64")) + tvm.ir.assert_structural_equal(clone.simplify(w + 1), tirx.const(9, "int64")) + assert analyzer.enabled_extensions == Extension.ComparisonOfProductAndSum + assert clone.enabled_extensions == Extension.NoExtensions + assert clone.try_compare(t, tirx.const(0, "int64")) == CompareResult.GE + assert analyzer.try_compare(t, tirx.const(0, "int64")) == CompareResult.UNKNOWN + + +def test_analyzer_object_clone_resets_rewrite_stats(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int64") + y = tirx.Var("y", "int64") + analyzer.bind(x, tvm.ir.Range(0, 8)) + analyzer.bind(y, tvm.ir.Range(0, 8)) + analyzer.simplify((x + y) * 2 - x - y) + source_attempts = analyzer.rewrite_simplify_stats.rewrites_attempted + assert source_attempts > 0 + + clone = analyzer.clone() + assert clone.rewrite_simplify_stats.rewrites_attempted == 0 + assert analyzer.rewrite_simplify_stats.rewrites_attempted == source_attempts + + if __name__ == "__main__": tvm.testing.main()