11#ifndef SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP
22#define SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP
33
4+ #include < SeQuant/core/algorithm.hpp>
45#include < SeQuant/core/container.hpp>
56#include < SeQuant/core/expr.hpp>
7+ #include < SeQuant/core/tensor_canonicalizer.hpp>
68#include < SeQuant/core/tensor_network.hpp>
79#include < SeQuant/core/utility/indices.hpp>
810#include < SeQuant/core/utility/macros.hpp>
11+ #include < SeQuant/external/bliss/graph.hh>
912
1013#include < range/v3/view.hpp>
1114
15+ #include < algorithm>
1216#include < bit>
17+ #include < limits>
1318#include < type_traits>
1419
1520namespace sequant ::opt {
@@ -42,22 +47,83 @@ auto constexpr flops_counter(has_index_extent auto&& ixex) {
4247// /
4348struct OptRes {
4449 // / Free indices remaining upon evaluation
45- container::svector<sequant::Index> indices;
50+ IndexSet indices;
4651
4752 // / The flops count of evaluation
4853 double flops;
4954
5055 // / The evaluation sequence
5156 EvalSequence sequence;
57+
58+ // / Bitmask splits that resulted into this OptRes
59+ size_t lp = 0 ;
60+ size_t rp = 0 ;
61+
62+ // / unique canonical subnets in the optimal tree for this bitmask
63+ container::vector<size_t > subnets;
64+ };
65+
66+ struct SubNetHash {
67+ size_t operator ()(
68+ TensorNetwork::SlotCanonicalizationMetadata const & data) const noexcept {
69+ return data.hash_value ();
70+ }
71+ };
72+
73+ struct SubNetEqual {
74+ bool operator ()(
75+ TensorNetwork::SlotCanonicalizationMetadata const & left,
76+ TensorNetwork::SlotCanonicalizationMetadata const & right) const {
77+ return bliss::ConstGraphCmp::cmp (*left.graph , *right.graph ) == 0 ;
78+ }
5279};
5380
81+ // / \brief Finds the optimal evaluation sequence for a single-term tensor
82+ // / contraction.
83+ // /
84+ // / This function employs an exhaustive search using dynamic programming to
85+ // / determine the contraction order that minimizes the total cost, as defined by
86+ // / the provided cost function.
87+ // /
88+ // / \tparam CostFn A function object type that computes the cost of a single
89+ // / binary contraction.
90+ // / Expected signature:
91+ // / \code double(meta::range_of<Index> auto const& lhs,
92+ // / meta::range_of<Index> auto const& rhs,
93+ // / meta::range_of<Index> auto const& res)
94+ // / \endcode
95+ // /
96+ // / \param network The \ref TensorNetwork containing the tensors to be
97+ // / contracted.
98+ // / \param tidxs The set of indices that should remain open in the
99+ // / final result.
100+ // / \param cost_fn The cost model used to evaluate contractions
101+ // / (e.g., flop count).
102+ // / \param subnet_cse If true, enables Common Subexpression
103+ // / Elimination (CSE) for
104+ // / equivalent subnetworks. When enabled, the cost of
105+ // / evaluating structurally identical subnetworks is counted
106+ // / only once in the total cost of a contraction tree.
107+ // / Equivalence is determined by canonicalizing the subnetwork
108+ // / graph.
109+ // /
110+ // / \return An \ref EvalSequence representing the optimal contraction order.
111+ // /
112+ // / \details The optimization uses a bitmask-based dynamic programming approach
113+ // / where each state represents a subnetwork (subset of tensors).
114+ // / If \p subnet_cse is enabled, the algorithm precomputes canonical
115+ // / metadata for every possible subnetwork to identify common
116+ // / structures. This allows it to find trees that benefit from reusing
117+ // / intermediate results, which is particularly effective for
118+ // / expressions with repeating tensor patterns.
119+ // /
54120template <typename CostFn>
55121 requires requires (CostFn&& fn, decltype (OptRes::indices) const & ixs) {
56- { std::forward<CostFn>(fn) (ixs, ixs, ixs) } -> std::floating_point;
122+ { fn (ixs, ixs, ixs) } -> std::floating_point;
57123 }
58124EvalSequence single_term_opt_impl (TensorNetwork const & network,
59125 meta::range_of<Index> auto const & tidxs,
60- CostFn&& cost_fn) {
126+ CostFn&& cost_fn, bool subnet_cse ) {
61127 using ranges::views::concat;
62128 using ranges::views::indirect;
63129 using ranges::views::transform;
@@ -88,26 +154,96 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
88154 }
89155 }
90156
157+ // precompute all subnet_meta if subnet_cse is true
158+ // Note: the O(2^n) cost is bounded in practice — subset_target_indices above
159+ // asserts n <= 24, capping the number of subsets at ~16M.
160+ container::vector<size_t > meta_ids;
161+ container::vector<double > unique_meta_costs;
162+ if (subnet_cse) {
163+ // Use max as sentinel for entries with popcount < 2 (singletons/empty),
164+ // which are skipped below and never assigned a real meta ID.
165+ meta_ids.resize (results.size (), std::numeric_limits<size_t >::max ());
166+ container::unordered_map<TensorNetwork::SlotCanonicalizationMetadata,
167+ size_t , SubNetHash, SubNetEqual>
168+ meta_to_id;
169+
170+ for (size_t n = 0 ; n < results.size (); ++n) {
171+ if (std::popcount (n) < 2 ) continue ;
172+ auto ts = bits::on_bits_index (n) | bits::sieve (network.tensors ());
173+ container::vector<ExprPtr> ts_expr;
174+ for (auto && t : ts) {
175+ ts_expr.emplace_back (std::dynamic_pointer_cast<Tensor>(t)->clone ());
176+ }
177+ auto tn = TensorNetwork{ts_expr};
178+ auto meta = tn.canonicalize_slots (
179+ TensorCanonicalizer::cardinal_tensor_labels (), &results[n].indices );
180+
181+ auto [it, inserted] = meta_to_id.try_emplace (std::move (meta), 0 );
182+ if (inserted) {
183+ it->second = meta_to_id.size () - 1 ;
184+ }
185+ meta_ids[n] = it->second ;
186+ }
187+ unique_meta_costs.resize (meta_to_id.size (), 0.0 );
188+ }
189+
91190 // find the optimal evaluation sequence
92191 for (size_t n = 0 ; n < results.size (); ++n) {
93192 if (std::popcount (n) < 2 ) continue ;
94- std::pair<size_t , size_t > curr_parts{0 , 0 };
95193 for (auto & curr_cost = results[n].flops ;
96194 auto && [lp, rp] : bits::bipartitions (n)) {
97195 // do nothing with the trivial bipartition
98196 // i.e. one subset is the empty set and the other full
99197 if (lp == 0 || rp == 0 ) continue ;
100- auto new_cost = std::forward<CostFn>(cost_fn)(results[lp].indices , //
101- results[rp].indices , //
102- results[n].indices ) //
103- + results[lp].flops + results[rp].flops ;
198+
199+ double new_cost = 0 ;
200+ container::vector<size_t > combined_subnets;
201+ if (subnet_cse) {
202+ // subnets is always kept sorted; set_union requires sorted inputs and
203+ // produces sorted output — this invariant is maintained throughout.
204+ std::set_union (results[lp].subnets .begin (), results[lp].subnets .end (),
205+ results[rp].subnets .begin (), results[rp].subnets .end (),
206+ std::back_inserter (combined_subnets));
207+ new_cost = cost_fn (results[lp].indices , //
208+ results[rp].indices , //
209+ results[n].indices );
210+ for (auto id : combined_subnets) {
211+ new_cost += unique_meta_costs[id];
212+ }
213+ } else {
214+ new_cost = cost_fn (results[lp].indices , //
215+ results[rp].indices , //
216+ results[n].indices ) //
217+ + results[lp].flops + results[rp].flops ;
218+ }
219+
104220 if (new_cost <= curr_cost) {
105221 curr_cost = new_cost;
106- curr_parts = decltype (curr_parts){lp, rp};
222+ results[n].lp = lp;
223+ results[n].rp = rp;
224+ if (subnet_cse) {
225+ results[n].subnets = std::move (combined_subnets);
226+ }
107227 }
108228 }
109- auto const & lseq = results[curr_parts.first ].sequence ;
110- auto const & rseq = results[curr_parts.second ].sequence ;
229+
230+ if (subnet_cse) {
231+ auto mid = meta_ids[n];
232+ // Canonically equivalent subnetworks share the same topology and index
233+ // sizes, so their cost is identical. Overwriting with a later bitmask's
234+ // cost is intentional and benign.
235+ unique_meta_costs[mid] =
236+ cost_fn (results[results[n].lp ].indices ,
237+ results[results[n].rp ].indices , results[n].indices );
238+ auto it = std::lower_bound (results[n].subnets .begin (),
239+ results[n].subnets .end (), mid);
240+ if (it == results[n].subnets .end () || *it != mid) {
241+ results[n].subnets .insert (it, mid);
242+ }
243+ }
244+
245+ auto const & lseq = results[results[n].lp ].sequence ;
246+ auto const & rseq = results[results[n].rp ].sequence ;
111247 results[n].sequence =
112248 (lseq[0 ] < rseq[0 ] ? concat (lseq, rseq) : concat (rseq, lseq)) |
113249 ranges::to<EvalSequence>;
@@ -121,15 +257,19 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
121257// / \tparam IdxToSz
122258// / \param network A TensorNetwork object.
123259// / \param idxsz An invocable on Index, that maps Index to its dimension.
124- // / \return Optimal evaluation sequence that minimizes flops. If there are
260+ // / \param subnet_cse Whether to recognize equivalent subnetworks to try
261+ // / minimizing the ops counts.
262+ // / \return Optimal evaluation sequence that
263+ // / minimizes flops. If there are
125264// / equivalent optimal sequences then the result is the one that keeps
126265// / the order of tensors in the network as original as possible.
127266// /
128267template <has_index_extent IdxToSz>
129- EvalSequence single_term_opt (TensorNetwork const & network, IdxToSz&& idxsz) {
268+ EvalSequence single_term_opt (TensorNetwork const & network, IdxToSz&& idxsz,
269+ bool subnet_cse) {
130270 auto cost_fn = flops_counter (std::forward<IdxToSz>(idxsz));
131271 decltype (OptRes::indices) tidxs{};
132- return single_term_opt_impl (network, tidxs, cost_fn);
272+ return single_term_opt_impl (network, tidxs, cost_fn, subnet_cse );
133273}
134274
135275} // namespace detail
@@ -142,7 +282,8 @@ EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz) {
142282// / @note @c prod is assumed to consist of only Tensor expressions
143283// /
144284template <has_index_extent IdxToSz>
145- ExprPtr single_term_opt (Product const & prod, IdxToSz&& idxsz) {
285+ ExprPtr single_term_opt (Product const & prod, IdxToSz&& idxsz,
286+ bool subnet_cse = false ) {
146287 using ranges::views::filter;
147288 using ranges::views::reverse;
148289
@@ -152,7 +293,7 @@ ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) {
152293 auto const tensors =
153294 prod | filter (&ExprPtr::template is<Tensor>) | ranges::to_vector;
154295 auto seq = detail::single_term_opt (TensorNetwork{tensors},
155- std::forward<IdxToSz>(idxsz));
296+ std::forward<IdxToSz>(idxsz), subnet_cse );
156297 auto result = container::svector<ExprPtr>{};
157298 for (auto i : seq)
158299 if (i == -1 ) {
0 commit comments