Skip to content

Commit 3fa55d2

Browse files
committed
cleanup
1 parent ae413af commit 3fa55d2

2 files changed

Lines changed: 31 additions & 54 deletions

File tree

SeQuant/core/optimize/single_term.hpp

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,6 @@ struct OptRes {
6262
container::vector<uint16_t> subnets;
6363
};
6464

65-
// constexpr auto cse_hasher = [](TNMeta const& data) -> size_t {
66-
// return data.hash_value();
67-
// };
68-
69-
// constexpr auto cse_equal = [](TNMeta const& left,
70-
// TNMeta const& right) -> bool {
71-
// return bliss::ConstGraphCmp::cmp(*left.graph, *right.graph) == 0;
72-
// };
73-
7465
struct SubNetHash {
7566
size_t operator()(
7667
TensorNetwork::SlotCanonicalizationMetadata const& data) const noexcept {
@@ -95,30 +86,35 @@ struct SubNetEqual {
9586
///
9687
/// \tparam CostFn A function object type that computes the cost of a single
9788
/// binary contraction.
98-
/// Expected signature: \code double(meta::range_of<Index> auto
99-
/// const& lhs, meta::range_of<Index> auto const& rhs,
100-
/// meta::range_of<Index> auto const& res) \endcode
89+
/// Expected signature:
90+
// \code double(meta::range_of<Index> auto const& lhs,
91+
// meta::range_of<Index> auto const& rhs,
92+
/// meta::range_of<Index> auto const& res)
93+
// \endcode
10194
///
10295
/// \param network The \ref TensorNetwork containing the tensors to be
103-
/// contracted. \param tidxs The set of indices that should remain open in the
104-
/// final result. \param cost_fn The cost model used to evaluate contractions
105-
/// (e.g., flop count). \param subnet_cse If true, enables Common Subexpression
96+
/// contracted.
97+
// \param tidxs The set of indices that should remain open in the
98+
/// final result.
99+
// \param cost_fn The cost model used to evaluate contractions
100+
/// (e.g., flop count).
101+
// \param subnet_cse If true, enables Common Subexpression
106102
/// Elimination (CSE) for
107-
/// equivalent subnetworks. When enabled, the cost of
108-
/// evaluating structurally identical subnetworks is counted
109-
/// only once in the total cost of a contraction tree.
110-
/// Equivalence is determined by canonicalizing the subnetwork
111-
/// graph.
103+
/// equivalent subnetworks. When enabled, the cost of
104+
/// evaluating structurally identical subnetworks is counted
105+
/// only once in the total cost of a contraction tree.
106+
/// Equivalence is determined by canonicalizing the subnetwork
107+
/// graph.
112108
///
113109
/// \return An \ref EvalSequence representing the optimal contraction order.
114110
///
115111
/// \details The optimization uses a bitmask-based dynamic programming approach
116-
/// where each state represents a subnetwork (subset of tensors).
117-
/// If \p subnet_cse is enabled, the algorithm precomputes canonical
118-
/// metadata for every possible subnetwork to identify common
119-
/// structures. This allows it to find trees that benefit from reusing
120-
/// intermediate results, which is particularly effective for
121-
/// expressions with repeating tensor patterns.
112+
/// where each state represents a subnetwork (subset of tensors).
113+
/// If \p subnet_cse is enabled, the algorithm precomputes canonical
114+
/// metadata for every possible subnetwork to identify common
115+
/// structures. This allows it to find trees that benefit from reusing
116+
/// intermediate results, which is particularly effective for
117+
/// expressions with repeating tensor patterns.
122118
///
123119
template <typename CostFn>
124120
requires requires(CostFn&& fn, decltype(OptRes::indices) const& ixs) {

tests/unit/test_optimize.cpp

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -353,16 +353,21 @@ TEST_CASE("optimize", "[optimize]") {
353353
auto prod9 =
354354
deserialize("X{i1;a1} X{i2;a4} Y{a3;i3} Y{a1;i4}")->as<Product>();
355355
auto res9 = single_term_opt(prod9);
356-
// take a look at res9_ for a result with subnet_cse disabled
357-
// should give different result
358-
// auto res9_ = single_term_opt(prod9, false);
359-
// std::wcout << "res9_\n" << serialize(res9_) << std::endl;
360356
// this is the one we want to find
361357
// (X Y) (X Y)
362358
REQUIRE(extract(res9, {0, 0}) == prod9.at(0));
363359
REQUIRE(extract(res9, {0, 1}) == prod9.at(3));
364360
REQUIRE(extract(res9, {1, 0}) == prod9.at(1));
365361
REQUIRE(extract(res9, {1, 1}) == prod9.at(2));
362+
363+
// take a look at res9_ for a result with subnet_cse disabled
364+
// should give different result
365+
// std::wcout << "res9_\n" << serialize(res9_) << std::endl;
366+
auto res9_no_cse = single_term_opt(prod9, false);
367+
REQUIRE(extract(res9_no_cse, {0, 0, 0}) == prod9.at(0));
368+
REQUIRE(extract(res9_no_cse, {0, 0, 1}) == prod9.at(3));
369+
REQUIRE(extract(res9_no_cse, {0, 1}) == prod9.at(1));
370+
REQUIRE(extract(res9_no_cse, {1}) == prod9.at(2));
366371
}
367372

368373
/// verify that space changes did not leak
@@ -371,27 +376,3 @@ TEST_CASE("optimize", "[optimize]") {
371376
REQUIRE(uocc);
372377
REQUIRE(uocc->approximate_size() == 10);
373378
}
374-
375-
TEST_CASE("feature optimize", "[feature]") {
376-
using namespace sequant;
377-
auto ctx_resetter = set_scoped_default_context(get_default_context().clone());
378-
auto reg = get_default_context().mutable_index_space_registry();
379-
mbpt::add_df_spaces(reg);
380-
mbpt::add_pao_spaces(reg);
381-
mbpt::add_ao_spaces(reg);
382-
// i 10
383-
// a 40
384-
// μ̃ 50
385-
// Κ 90
386-
for (auto&& [k, v] :
387-
std::initializer_list<std::pair<std::wstring_view, size_t>>{
388-
{L"i", 10}, {L"a", 40}, {L"μ̃", 50}, {L"Κ", 90}}) {
389-
reg->retrieve_ptr(k)->approximate_size(v);
390-
}
391-
392-
for (auto&& ix :
393-
std::initializer_list<std::wstring_view>{L"i", L"a", L"μ̃", L"Κ"}) {
394-
std::wcout << std::format(L"{}: {}\n", ix,
395-
reg->retrieve_ptr(ix)->approximate_size());
396-
}
397-
}

0 commit comments

Comments
 (0)