Skip to content

Commit 4cbd9da

Browse files
committed
Cleanup and comments
1 parent 91fd0a8 commit 4cbd9da

1 file changed

Lines changed: 19 additions & 9 deletions

File tree

SeQuant/core/optimize/single_term.hpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
#ifndef SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP
22
#define SEQUANT_CORE_OPTIMIZE_SINGLE_TERM_HPP
33

4+
#include <SeQuant/core/algorithm.hpp>
45
#include <SeQuant/core/container.hpp>
56
#include <SeQuant/core/expr.hpp>
7+
#include <SeQuant/core/tensor_canonicalizer.hpp>
68
#include <SeQuant/core/tensor_network.hpp>
79
#include <SeQuant/core/utility/indices.hpp>
810
#include <SeQuant/core/utility/macros.hpp>
911
#include <SeQuant/external/bliss/graph.hh>
1012

1113
#include <range/v3/view.hpp>
1214

13-
#include <SeQuant/core/algorithm.hpp>
14-
#include <SeQuant/core/tensor_canonicalizer.hpp>
1515
#include <algorithm>
1616
#include <bit>
17+
#include <limits>
1718
#include <type_traits>
1819

1920
namespace sequant::opt {
@@ -87,18 +88,18 @@ struct SubNetEqual {
8788
/// \tparam CostFn A function object type that computes the cost of a single
8889
/// binary contraction.
8990
/// Expected signature:
90-
// \code double(meta::range_of<Index> auto const& lhs,
91-
// meta::range_of<Index> auto const& rhs,
91+
/// \code double(meta::range_of<Index> auto const& lhs,
92+
/// meta::range_of<Index> auto const& rhs,
9293
/// meta::range_of<Index> auto const& res)
93-
// \endcode
94+
/// \endcode
9495
///
9596
/// \param network The \ref TensorNetwork containing the tensors to be
9697
/// contracted.
97-
// \param tidxs The set of indices that should remain open in the
98+
/// \param tidxs The set of indices that should remain open in the
9899
/// final result.
99-
// \param cost_fn The cost model used to evaluate contractions
100+
/// \param cost_fn The cost model used to evaluate contractions
100101
/// (e.g., flop count).
101-
// \param subnet_cse If true, enables Common Subexpression
102+
/// \param subnet_cse If true, enables Common Subexpression
102103
/// Elimination (CSE) for
103104
/// equivalent subnetworks. When enabled, the cost of
104105
/// evaluating structurally identical subnetworks is counted
@@ -154,10 +155,14 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
154155
}
155156

156157
// precompute all subnet_meta if subnet_cse is true
158+
// Note: the O(2^n) cost is bounded in practice — subset_target_indices above
159+
// asserts n <= 24, capping the number of subsets at 2^24 (~16.8M).
157160
container::vector<uint16_t> meta_ids;
158161
container::vector<double> unique_meta_costs;
159162
if (subnet_cse) {
160-
meta_ids.resize(results.size(), 0);
163+
// Use the maximum uint16_t value as a sentinel for entries with popcount < 2 (singletons/empty),
164+
// which are skipped below and never assigned a real meta ID.
165+
meta_ids.resize(results.size(), std::numeric_limits<uint16_t>::max());
161166
container::unordered_map<TensorNetwork::SlotCanonicalizationMetadata,
162167
uint16_t, SubNetHash, SubNetEqual>
163168
meta_to_id;
@@ -194,6 +199,8 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
194199
double new_cost = 0;
195200
container::vector<uint16_t> combined_subnets;
196201
if (subnet_cse) {
202+
// subnets is always kept sorted; set_union requires sorted inputs and
203+
// produces sorted output — this invariant is maintained throughout.
197204
std::set_union(results[lp].subnets.begin(), results[lp].subnets.end(),
198205
results[rp].subnets.begin(), results[rp].subnets.end(),
199206
std::back_inserter(combined_subnets));
@@ -222,6 +229,9 @@ EvalSequence single_term_opt_impl(TensorNetwork const& network,
222229

223230
if (subnet_cse) {
224231
auto mid = meta_ids[n];
232+
// Canonically equivalent subnetworks share the same topology and index
233+
// sizes, so their cost is identical. Overwriting with a later bitmask's
234+
// cost is intentional and benign.
225235
unique_meta_costs[mid] =
226236
cost_fn(results[results[n].lp].indices,
227237
results[results[n].rp].indices, results[n].indices);

0 commit comments

Comments
 (0)