Skip to content

Commit 48f9f5f

Browse files
authored
Merge pull request #506 from ValeevGroup/STO_with_TN_canon
feat: common subnetwork recognition (intra-term) within STO
2 parents b8307ec + b3a1372 commit 48f9f5f

3 files changed

Lines changed: 256 additions & 21 deletions

File tree

SeQuant/core/eval/eval_expr.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ EvalExprNode binarize(Product const& prod, IndexSet const& uncontract) {
461461
auto counts = get_used_indices_with_counts(prod);
462462
IndexGroups<IndexVec> result;
463463
for (auto&& [k, v] : counts) {
464+
if (v.nonproto() == 0) continue;
464465
if (v.total() > 1) {
465466
if (uncontracted_idxs.contains(k)) result.aux.emplace_back(k);
466467
continue;

SeQuant/core/optimize/single_term.hpp

Lines changed: 157 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
#ifndef SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP
22
#define SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP
33

4+
#include <SeQuant/core/algorithm.hpp>
45
#include <SeQuant/core/container.hpp>
56
#include <SeQuant/core/expr.hpp>
7+
#include <SeQuant/core/tensor_canonicalizer.hpp>
68
#include <SeQuant/core/tensor_network.hpp>
79
#include <SeQuant/core/utility/indices.hpp>
810
#include <SeQuant/core/utility/macros.hpp>
11+
#include <SeQuant/external/bliss/graph.hh>
912

1013
#include <range/v3/view.hpp>
1114

15+
#include <algorithm>
1216
#include <bit>
17+
#include <limits>
1318
#include <type_traits>
1419

1520
namespace sequant::opt {
@@ -42,22 +47,83 @@ auto constexpr flops_counter(has_index_extent auto&& ixex) {
4247
///
4348
struct OptRes {
4449
/// Free indices remaining upon evaluation
45-
container::svector<sequant::Index> indices;
50+
IndexSet indices;
4651

4752
/// The flops count of evaluation
4853
double flops;
4954

5055
/// The evaluation sequence
5156
EvalSequence sequence;
57+
58+
/// Bitmask splits that resulted into this OptRes
59+
size_t lp = 0;
60+
size_t rp = 0;
61+
62+
/// unique canonical subnets in the optimal tree for this bitmask
63+
container::vector<size_t> subnets;
64+
};
65+
66+
struct SubNetHash {
67+
size_t operator()(
68+
TensorNetwork::SlotCanonicalizationMetadata const& data) const noexcept {
69+
return data.hash_value();
70+
}
71+
};
72+
73+
struct SubNetEqual {
74+
bool operator()(
75+
TensorNetwork::SlotCanonicalizationMetadata const& left,
76+
TensorNetwork::SlotCanonicalizationMetadata const& right) const {
77+
return bliss::ConstGraphCmp::cmp(*left.graph, *right.graph) == 0;
78+
}
5279
};
5380

81+
/// \brief Finds the optimal evaluation sequence for a single-term tensor
82+
/// contraction.
83+
///
84+
/// This function employs an exhaustive search using dynamic programming to
85+
/// determine the contraction order that minimizes the total cost, as defined by
86+
/// the provided cost function.
87+
///
88+
/// \tparam CostFn A function object type that computes the cost of a single
89+
/// binary contraction.
90+
/// Expected signature:
91+
/// \code double(meta::range_of<Index> auto const& lhs,
92+
/// meta::range_of<Index> auto const& rhs,
93+
/// meta::range_of<Index> auto const& res)
94+
/// \endcode
95+
///
96+
/// \param network The \ref TensorNetwork containing the tensors to be
97+
/// contracted.
98+
/// \param tidxs The set of indices that should remain open in the
99+
/// final result.
100+
/// \param cost_fn The cost model used to evaluate contractions
101+
/// (e.g., flop count).
102+
/// \param subnet_cse If true, enables Common Subexpression
103+
/// Elimination (CSE) for
104+
/// equivalent subnetworks. When enabled, the cost of
105+
/// evaluating structurally identical subnetworks is counted
106+
/// only once in the total cost of a contraction tree.
107+
/// Equivalence is determined by canonicalizing the subnetwork
108+
/// graph.
109+
///
110+
/// \return An \ref EvalSequence representing the optimal contraction order.
111+
///
112+
/// \details The optimization uses a bitmask-based dynamic programming approach
113+
/// where each state represents a subnetwork (subset of tensors).
114+
/// If \p subnet_cse is enabled, the algorithm precomputes canonical
115+
/// metadata for every possible subnetwork to identify common
116+
/// structures. This allows it to find trees that benefit from reusing
117+
/// intermediate results, which is particularly effective for
118+
/// expressions with repeating tensor patterns.
119+
///
54120
template <typename CostFn>
55121
requires requires(CostFn&& fn, decltype(OptRes::indices) const& ixs) {
56-
{ std::forward<CostFn>(fn)(ixs, ixs, ixs) } -> std::floating_point;
122+
{ fn(ixs, ixs, ixs) } -> std::floating_point;
57123
}
58124
EvalSequence single_term_opt_impl(TensorNetwork const& network,
59125
meta::range_of<Index> auto const& tidxs,
60-
CostFn&& cost_fn) {
126+
CostFn&& cost_fn, bool subnet_cse) {
61127
using ranges::views::concat;
62128
using ranges::views::indirect;
63129
using ranges::views::transform;
@@ -88,26 +154,96 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
88154
}
89155
}
90156

157+
// precompute all subnet_meta if subnet_cse is true
158+
// Note: the O(2^n) cost is bounded in practice — subset_target_indices above
159+
// asserts n <= 24, capping the number of subsets at ~16M.
160+
container::vector<size_t> meta_ids;
161+
container::vector<double> unique_meta_costs;
162+
if (subnet_cse) {
163+
// Use max as sentinel for entries with popcount < 2 (singletons/empty),
164+
// which are skipped below and never assigned a real meta ID.
165+
meta_ids.resize(results.size(), std::numeric_limits<size_t>::max());
166+
container::unordered_map<TensorNetwork::SlotCanonicalizationMetadata,
167+
size_t, SubNetHash, SubNetEqual>
168+
meta_to_id;
169+
170+
for (size_t n = 0; n < results.size(); ++n) {
171+
if (std::popcount(n) < 2) continue;
172+
auto ts = bits::on_bits_index(n) | bits::sieve(network.tensors());
173+
container::vector<ExprPtr> ts_expr;
174+
for (auto&& t : ts) {
175+
ts_expr.emplace_back(std::dynamic_pointer_cast<Tensor>(t)->clone());
176+
}
177+
auto tn = TensorNetwork{ts_expr};
178+
auto meta = tn.canonicalize_slots(
179+
TensorCanonicalizer::cardinal_tensor_labels(), &results[n].indices);
180+
181+
auto [it, inserted] = meta_to_id.try_emplace(std::move(meta), 0);
182+
if (inserted) {
183+
it->second = meta_to_id.size() - 1;
184+
}
185+
meta_ids[n] = it->second;
186+
}
187+
unique_meta_costs.resize(meta_to_id.size(), 0.0);
188+
}
189+
91190
// find the optimal evaluation sequence
92191
for (size_t n = 0; n < results.size(); ++n) {
93192
if (std::popcount(n) < 2) continue;
94-
std::pair<size_t, size_t> curr_parts{0, 0};
95193
for (auto& curr_cost = results[n].flops;
96194
auto&& [lp, rp] : bits::bipartitions(n)) {
97195
// do nothing with the trivial bipartition
98196
// i.e. one subset is the empty set and the other full
99197
if (lp == 0 || rp == 0) continue;
100-
auto new_cost = std::forward<CostFn>(cost_fn)(results[lp].indices, //
101-
results[rp].indices, //
102-
results[n].indices) //
103-
+ results[lp].flops + results[rp].flops;
198+
199+
double new_cost = 0;
200+
container::vector<size_t> combined_subnets;
201+
if (subnet_cse) {
202+
// subnets is always kept sorted; set_union requires sorted inputs and
203+
// produces sorted output — this invariant is maintained throughout.
204+
std::set_union(results[lp].subnets.begin(), results[lp].subnets.end(),
205+
results[rp].subnets.begin(), results[rp].subnets.end(),
206+
std::back_inserter(combined_subnets));
207+
new_cost = cost_fn(results[lp].indices, //
208+
results[rp].indices, //
209+
results[n].indices);
210+
for (auto id : combined_subnets) {
211+
new_cost += unique_meta_costs[id];
212+
}
213+
} else {
214+
new_cost = cost_fn(results[lp].indices, //
215+
results[rp].indices, //
216+
results[n].indices) //
217+
+ results[lp].flops + results[rp].flops;
218+
}
219+
104220
if (new_cost <= curr_cost) {
105221
curr_cost = new_cost;
106-
curr_parts = decltype(curr_parts){lp, rp};
222+
results[n].lp = lp;
223+
results[n].rp = rp;
224+
if (subnet_cse) {
225+
results[n].subnets = std::move(combined_subnets);
226+
}
107227
}
108228
}
109-
auto const& lseq = results[curr_parts.first].sequence;
110-
auto const& rseq = results[curr_parts.second].sequence;
229+
230+
if (subnet_cse) {
231+
auto mid = meta_ids[n];
232+
// Canonically equivalent subnetworks share the same topology and index
233+
// sizes, so their cost is identical. Overwriting with a later bitmask's
234+
// cost is intentional and benign.
235+
unique_meta_costs[mid] =
236+
cost_fn(results[results[n].lp].indices,
237+
results[results[n].rp].indices, results[n].indices);
238+
auto it = std::lower_bound(results[n].subnets.begin(),
239+
results[n].subnets.end(), mid);
240+
if (it == results[n].subnets.end() || *it != mid) {
241+
results[n].subnets.insert(it, mid);
242+
}
243+
}
244+
245+
auto const& lseq = results[results[n].lp].sequence;
246+
auto const& rseq = results[results[n].rp].sequence;
111247
results[n].sequence =
112248
(lseq[0] < rseq[0] ? concat(lseq, rseq) : concat(rseq, lseq)) |
113249
ranges::to<EvalSequence>;
@@ -121,15 +257,19 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
121257
/// \tparam IdxToSz
122258
/// \param network A TensorNetwork object.
123259
/// \param idxsz An invocable on Index, that maps Index to its dimension.
124-
/// \return Optimal evaluation sequence that minimizes flops. If there are
260+
/// \param subnet_cse Whether to recognize equivalent subnetworks to try
261+
/// minimizing the operation count.
262+
/// \return Optimal evaluation sequence that
263+
/// minimizes flops. If there are
125264
/// equivalent optimal sequences then the result is the one that keeps
126265
/// the order of tensors in the network as original as possible.
127266
///
128267
template <has_index_extent IdxToSz>
129-
EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz) {
268+
EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz,
269+
bool subnet_cse) {
130270
auto cost_fn = flops_counter(std::forward<IdxToSz>(idxsz));
131271
decltype(OptRes::indices) tidxs{};
132-
return single_term_opt_impl(network, tidxs, cost_fn);
272+
return single_term_opt_impl(network, tidxs, cost_fn, subnet_cse);
133273
}
134274

135275
} // namespace detail
@@ -142,7 +282,8 @@ EvalSequence single_term_opt(TensorNetwork const& network, IdxToSz&& idxsz) {
142282
/// @note @c prod is assumed to consist of only Tensor expressions
143283
///
144284
template <has_index_extent IdxToSz>
145-
ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) {
285+
ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz,
286+
bool subnet_cse = false) {
146287
using ranges::views::filter;
147288
using ranges::views::reverse;
148289

@@ -152,7 +293,7 @@ ExprPtr single_term_opt(Product const& prod, IdxToSz&& idxsz) {
152293
auto const tensors =
153294
prod | filter(&ExprPtr::template is<Tensor>) | ranges::to_vector;
154295
auto seq = detail::single_term_opt(TensorNetwork{tensors},
155-
std::forward<IdxToSz>(idxsz));
296+
std::forward<IdxToSz>(idxsz), subnet_cse);
156297
auto result = container::svector<ExprPtr>{};
157298
for (auto i : seq)
158299
if (i == -1) {

tests/unit/test_optimize.cpp

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include <cstddef>
1919
#include <initializer_list>
2020
#include <memory>
21-
#include <stdexcept>
2221

2322
#include <range/v3/all.hpp>
2423

@@ -323,9 +322,103 @@ TEST_CASE("optimize", "[optimize]") {
323322
}
324323
}
325324

325+
SECTION("Single term optimization with CSE") {
326+
auto ctx_resetter =
327+
set_scoped_default_context(get_default_context().clone());
328+
auto reg = get_default_context().mutable_index_space_registry();
329+
mbpt::add_df_spaces(reg);
330+
mbpt::add_pao_spaces(reg);
331+
mbpt::add_ao_spaces(reg);
332+
// i 10
333+
// a 40
334+
// μ̃ 50
335+
// Κ 90
336+
for (auto&& [k, v] :
337+
std::initializer_list<std::pair<std::wstring_view, size_t>>{
338+
{L"i", 10}, {L"a", 40}, {L"μ̃", 50}, {L"Κ", 90}}) {
339+
reg->retrieve_ptr(k)->approximate_size(v);
340+
}
341+
342+
auto single_term_opt = [](Product const& prod, bool cse = true) {
343+
return opt::single_term_opt(
344+
prod,
345+
[](Index const& ix) {
346+
// null space contributes x1 to the size
347+
auto sz = ix.nonnull() ? ix.space().approximate_size() : 1;
348+
return sz;
349+
},
350+
/*subnet_cse=*/cse);
351+
};
352+
353+
auto prod9 =
354+
deserialize("X{i1;a1} X{i2;a2} Y{a2;i3} Y{a1;i4}")->as<Product>();
355+
auto res9 = single_term_opt(prod9);
356+
auto res9_no_cse = single_term_opt(prod9, false);
357+
// this is the one we want to find
358+
// (X Y) (X Y)
359+
REQUIRE(extract(res9, {0, 0}) == prod9.at(0));
360+
REQUIRE(extract(res9, {0, 1}) == prod9.at(3));
361+
REQUIRE(extract(res9, {1, 0}) == prod9.at(1));
362+
REQUIRE(extract(res9, {1, 1}) == prod9.at(2));
363+
364+
// take a look at res9_no_cse for a result with subnet_cse disabled
365+
// should give the same result in this case as it's already optimal
366+
REQUIRE(extract(res9_no_cse, {0, 0}) == prod9.at(0));
367+
REQUIRE(extract(res9_no_cse, {0, 1}) == prod9.at(3));
368+
REQUIRE(extract(res9_no_cse, {1, 0}) == prod9.at(1));
369+
REQUIRE(extract(res9_no_cse, {1, 1}) == prod9.at(2));
370+
371+
SECTION("CSE effect on optimization result") {
372+
auto ctx_resetter =
373+
set_scoped_default_context(get_default_context().clone());
374+
auto reg = get_default_context().mutable_index_space_registry();
375+
// Use sizes that make the unbalanced tree better without CSE,
376+
// but the balanced tree better with CSE.
377+
// Balanced: ( (X1 Y1) (X2 Y2) )
378+
// Cost(X1*Y1) = size(i)*size(a)*size(j) = 12*10*12 = 1440.
379+
// Cost(Inter) = 12^3 = 1728.
380+
// Total no-CSE: 2*1440 + 1728 = 4608.
381+
// Total CSE: 1440 + 1728 = 3168.
382+
// Unbalanced: ( ( (X1 Y1) X2 ) Y2 )
383+
// Cost(X1*Y1) = 12*10*12 = 1440.
384+
// Cost((X1*Y1)*X2) = size(i)*size(i)*size(a) = 12*12*10 = 1440.
385+
// Cost(...) * Y2 = 12*10*12 = 1440.
386+
// Total Unbalanced: 1440 + 1440 + 1440 = 4320.
387+
// 3168 < 4320 < 4608.
388+
reg->retrieve_ptr(L"i")->approximate_size(12);
389+
reg->retrieve_ptr(L"a")->approximate_size(10);
390+
391+
auto single_term_opt = [](Product const& prod, bool cse) {
392+
return opt::single_term_opt(
393+
prod,
394+
[](Index const& ix) {
395+
return ix.nonnull() ? ix.space().approximate_size() : 1;
396+
},
397+
cse);
398+
};
399+
400+
// X{i1;a1} Y{a1;i2} X{i2;a2} Y{a2;i3}
401+
auto prod =
402+
deserialize(L"X{i1;a1} Y{a1;i2} X{i2;a2} Y{a2;i3}")->as<Product>();
403+
404+
auto res_cse = single_term_opt(prod, true);
405+
auto res_no_cse = single_term_opt(prod, false);
406+
407+
// With CSE: Balanced tree
408+
REQUIRE(res_cse->as<Product>().factors().size() == 2);
409+
REQUIRE(res_cse->at(0)->is<Product>());
410+
REQUIRE(res_cse->at(1)->is<Product>());
411+
412+
// Without CSE: Unbalanced tree
413+
bool is_unbalanced =
414+
(res_no_cse->at(0)->is<Tensor>() || res_no_cse->at(1)->is<Tensor>());
415+
REQUIRE(is_unbalanced);
416+
}
417+
}
418+
326419
/// verify that space changes did not leak
327-
auto reg = get_default_context().index_space_registry();
328-
auto uocc = reg->retrieve_ptr(L"a");
329-
REQUIRE(uocc);
330-
REQUIRE(uocc->approximate_size() == 10);
420+
auto reg_check = get_default_context().index_space_registry();
421+
auto uocc_check = reg_check->retrieve_ptr(L"a");
422+
REQUIRE(uocc_check);
423+
REQUIRE(uocc_check->approximate_size() == 10);
331424
}

0 commit comments

Comments
 (0)