Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/shamcomm/include/shamcomm/collectives.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ namespace shamcomm {
void gather_basic_str(
const std::basic_string<byte> &send_vec, std::basic_string<byte> &recv_vec);

/**
* @brief Allgathers a string from all nodes and concatenates it in a std::string
*
* This function gathers the string `send_vec` from all nodes and concatenates the
* result in `recv_vec` on every rank. The result is ordered by the order of the
* nodes in the communicator, i.e. the string is ordered by rank.
*
* @param send_vec The local string contributed by the calling rank.
* @param recv_vec Output string; its previous content is overwritten with the
*                 rank-ordered concatenation of all contributions.
*/
void allgather_str(const std::string &send_vec, std::string &recv_vec);

/// Same as allgather_str, but operating on std::basic_string<byte> buffers:
/// `send_vec` is the local contribution and `recv_vec` is overwritten with the
/// rank-ordered concatenation on every rank.
void allgather_basic_str(
const std::basic_string<byte> &send_vec, std::basic_string<byte> &recv_vec);

/**
* @brief Constructs a histogram from a vector of strings, counting occurrences
* of each unique string.
Expand All @@ -56,8 +69,11 @@ namespace shamcomm {
* @return An unordered map where keys are unique strings from the input and
* values are the counts of their occurrences. (valid only on rank 0)
*/

std::unordered_map<std::string, int> string_histogram(
const std::vector<std::string> &inputs, std::string delimiter = "\n");

/// Same as string_histogram, but the resulting histogram is returned on every rank
/// (not only on rank 0).
std::unordered_map<std::string, int> all_string_histogram(
const std::vector<std::string> &inputs, std::string delimiter = "\n");

} // namespace shamcomm
4 changes: 4 additions & 0 deletions src/shamcomm/include/shamcomm/wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "shambase/aliases_float.hpp"
#include "shambase/aliases_int.hpp"
#include "shamcomm/mpi.hpp"
#include <unordered_map>
#include <string>

namespace shamcomm::mpi {
Expand All @@ -29,6 +30,9 @@ namespace shamcomm::mpi {
/// Get the current value of the internal timer named `timername`
f64 get_timer(std::string timername);

/// Return a read-only reference to the map of all internal timers,
/// keyed by timer name.
const std::unordered_map<std::string, f64> &get_timers();

/// MPI wrapper for MPI_Allreduce
void Allreduce(
const void *sendbuf,
Expand Down
82 changes: 82 additions & 0 deletions src/shamcomm/src/collectives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,56 @@ namespace {
recv_vec = result;
}

/**
* @brief Allgather a vector of characters from all MPI ranks into a single string
*
* The resulting string is concatenated in rank order and is returned on every rank.
*/
template<class Tchar>
inline void _internal_allgather_str(
const std::basic_string<Tchar> &send_vec, std::basic_string<Tchar> &recv_vec) {
StackEntry stack_loc{};

if (shamcomm::world_size() == 1) {
recv_vec = send_vec;
return;
}

i32 wsize = shamcomm::world_size();
size_t wsize_sz = static_cast<size_t>(wsize);

// counts/displacements are expressed in number of characters.
std::vector<int> counts(wsize_sz);
std::vector<int> disps(wsize_sz);

// MPI counts/displacements use `int`.
int local_count = static_cast<int>(send_vec.size());

shamcomm::mpi::Allgather(
&local_count, 1, MPI_INT, counts.data(), 1, MPI_INT, MPI_COMM_WORLD);

for (size_t i = 0; i < wsize_sz; i++) {
disps[i] = (i > 0) ? (disps[i - 1] + counts[i - 1]) : 0;
}
Comment thread
tdavidcl marked this conversation as resolved.

int global_len = disps[wsize_sz - 1] + counts[wsize_sz - 1];

std::basic_string<Tchar> result;
result.resize(static_cast<size_t>(global_len));

shamcomm::mpi::Allgatherv(
send_vec.data(),
local_count,
MPI_CHAR,
result.data(),
counts.data(),
disps.data(),
MPI_CHAR,
Comment thread
tdavidcl marked this conversation as resolved.
MPI_COMM_WORLD);

recv_vec = result;
}

} // namespace

void shamcomm::gather_str(const std::string &send_vec, std::string &recv_vec) {
Expand All @@ -94,6 +144,17 @@ void shamcomm::gather_basic_str(
_internal_gather_str(send_vec, recv_vec);
}

/// Allgather `send_vec` from every rank into `recv_vec` (overwritten on every rank).
/// Thin wrapper that forwards to the templated _internal_allgather_str.
void shamcomm::allgather_str(const std::string &send_vec, std::string &recv_vec) {
StackEntry stack_loc{};
_internal_allgather_str(send_vec, recv_vec);
}

/// std::basic_string<byte> flavour of allgather_str.
/// Thin wrapper that forwards to the templated _internal_allgather_str.
void shamcomm::allgather_basic_str(
const std::basic_string<byte> &send_vec, std::basic_string<byte> &recv_vec) {
StackEntry stack_loc{};
_internal_allgather_str(send_vec, recv_vec);
}

std::unordered_map<std::string, int> shamcomm::string_histogram(
const std::vector<std::string> &inputs, std::string delimiter) {
std::string accum_loc = "";
Expand All @@ -119,3 +180,24 @@ std::unordered_map<std::string, int> shamcomm::string_histogram(

return {};
}

/// Count the occurrences of each string over the inputs of all ranks and return
/// the resulting histogram on every rank.
///
/// @param inputs    Strings contributed by the calling rank.
/// @param delimiter Separator appended after each input before the exchange and
///                  used to split the gathered string afterwards.
/// @return Map from each string to its number of occurrences across all ranks.
std::unordered_map<std::string, int> shamcomm::all_string_histogram(
const std::vector<std::string> &inputs, std::string delimiter) {
// Serialize the local inputs as "<s0><delim><s1><delim>..." so that a single
// string can be exchanged.
// NOTE(review): an input that itself contains `delimiter` would be split into
// several entries after the exchange -- confirm callers never pass such strings.
std::string accum_loc = "";
for (auto &s : inputs) {
accum_loc += s + delimiter;
}
Comment on lines +186 to +189
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Repeatedly concatenating strings with += inside a loop can be inefficient due to multiple reallocations. It's more performant to first calculate the total required size, reserve the memory for the string, and then append the parts. This avoids intermediate allocations.

Example:

std::string accum_loc;
size_t total_size = 0;
for (const auto& s : inputs) {
    total_size += s.size() + delimiter.size();
}
accum_loc.reserve(total_size);
for (const auto &s : inputs) {
    accum_loc.append(s);
    accum_loc.append(delimiter);
}


// Every rank receives the rank-ordered concatenation of all serialized inputs.
std::string recv = "";
allgather_str(accum_loc, recv);

// Recover the individual strings from the gathered buffer.
// NOTE(review): whether the trailing delimiter yields an empty last token
// depends on shambase::split_str, which is not visible here -- verify.
std::vector<std::string> splitted = shambase::split_str(recv, delimiter);

std::unordered_map<std::string, int> histogram;

// operator[] value-initializes the count to 0 on first sight of a string.
for (size_t i = 0; i < splitted.size(); i++) {
histogram[splitted[i]] += 1;
}
Comment on lines +198 to +200
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability and to follow modern C++ practices, you can use a range-based for loop here.

Suggested change
for (size_t i = 0; i < splitted.size(); i++) {
histogram[splitted[i]] += 1;
}
for (const auto& s : splitted) {
histogram[s]++;
}


return histogram;
}
2 changes: 2 additions & 0 deletions src/shamcomm/src/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ namespace shamcomm::mpi {

/// Return the value stored in `mpi_timers` for `timername`.
/// Note: the lookup uses operator[], so querying a name that has no entry yet
/// inserts a value-initialized entry for it.
f64 get_timer(std::string timername) { return mpi_timers[timername]; }

/// Return a read-only reference to the `mpi_timers` map (timer name -> value).
const std::unordered_map<std::string, f64> &get_timers() { return mpi_timers; }

} // namespace shamcomm::mpi

namespace {
Expand Down
59 changes: 59 additions & 0 deletions src/shampylib/src/pyShamcomm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

/**
* @file pyShamcomm.cpp
* @author Timothée David--Cléris (tim.shamrock@proton.me)
* @brief Python bindings for the shamcomm library (`comm` submodule: MPI timer accessors)
*/

#include "shamalgs/collective/reduction.hpp"
#include "shambindings/pybind11_stl.hpp"
#include "shambindings/pybindaliases.hpp"
#include "shambindings/pytypealias.hpp"
#include "shamcomm/collectives.hpp"
#include "shamcomm/logs.hpp"
#include "shamcomm/wrapper.hpp"
#include <pybind11/pytypes.h>
#include <unordered_map>
#include <utility>
#include <vector>

// Registers the `comm` Python submodule, exposing the shamcomm MPI timers:
//   - get_timer(name)              : value of a single timer
//   - get_timers()                 : all timers as a dict
//   - mpi_timers_delta(start, end) : per-timer (end - start), reduced with
//                                    allreduce_max over the ranks
Register_pymod(shamcommlibinit) {

py::module shamcomm_module = m.def_submodule("comm", "comm library");

shamcomm_module.def("get_timer", [](std::string name) {
return shamcomm::mpi::get_timer(std::move(name));
});

// The lambda has a deduced (non-reference) return type, so the timer map is
// returned to the caller as a copy of the internal map.
shamcomm_module.def("get_timers", []() {
return shamcomm::mpi::get_timers();
});

// `start` and `end` are two timer maps (timer name -> value), e.g. two results
// of get_timers(). They are taken by value because operator[] is used on them below.
shamcomm_module.def(
"mpi_timers_delta",
[](std::unordered_map<std::string, f64> start, std::unordered_map<std::string, f64> end) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The start and end maps are passed by value, which can cause unnecessary and potentially expensive copies. Passing them by const reference (const std::unordered_map<std::string, f64>&) would be more efficient. Note that if you make this change, you will need to use find() or at() instead of operator[] to access map elements, as operator[] is not a const operation.

Suggested change
[](std::unordered_map<std::string, f64> start, std::unordered_map<std::string, f64> end) {
[](const std::unordered_map<std::string, f64>& start, const std::unordered_map<std::string, f64>& end) {

// Collect the timer names known locally (taken from `end` only).
std::vector<std::string> keys{};

for (auto &[k, v] : end) {
keys.push_back(k);
}
Comment on lines +43 to +47
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid potential reallocations while populating the keys vector, it's more efficient to reserve its size beforehand using end.size().

Suggested change
std::vector<std::string> keys{};
for (auto &[k, v] : end) {
keys.push_back(k);
}
std::vector<std::string> keys;
keys.reserve(end.size());
for (auto const& [k, v] : end) {
keys.push_back(k);
}


// Build, on every rank, the set of timer names present on any rank; only the
// keys of the histogram are used below, the counts are ignored.
auto key_histo = shamcomm::all_string_histogram(keys);

std::unordered_map<std::string, f64> deltas{};

// A timer missing from `start` or `end` on this rank is default-inserted by
// operator[] and therefore contributes 0 to the difference.
// NOTE(review): allreduce_max is presumably a collective call, in which case
// every rank must visit the keys of `key_histo` in the same order -- confirm
// that the unordered_map iteration order is identical across ranks.
for (auto &[k, c] : key_histo) {
deltas[k] = shamalgs::collective::allreduce_max(end[k] - start[k]);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using operator[] on the maps can have the side effect of inserting a new element if the key doesn't exist. While this may be the intended behavior (to treat missing timers as 0), it's safer and clearer to use find() or count() to check for existence and retrieve the value. This also becomes necessary if you change the parameters to be const references for efficiency.

Example with count():

double end_val = end.count(k) ? end.at(k) : 0.0;
double start_val = start.count(k) ? start.at(k) : 0.0;
deltas[k] = shamalgs::collective::allreduce_max(end_val - start_val);

}

return deltas;
});
}
23 changes: 23 additions & 0 deletions src/tests/shamcomm/collectivesTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,26 @@ TestStart(Unittest, "shamcomm/collectives::gather_str", test_gather_str, 4) {

REQUIRE_EQUAL(recv, result);
}

// Checks that allgather_str returns, on every rank, the rank-ordered concatenation
// of the per-rank strings and that the previous content of the output is discarded.
// NOTE(review): the trailing `4` is presumably the number of ranks the test runs
// with; `ref_base` holds exactly one string per rank in that case -- confirm.
TestStart(Unittest, "shamcomm/collectives::allgather_str", test_allgather_str, 4) {

// One reference string per rank, with different lengths.
std::array<std::string, 4> ref_base{
"I'm a very important string",
"But I'm a very important string",
"Listen, I'm a very important string",
"The most importantest string",
};

// Expected result: all reference strings concatenated in rank order.
std::string result = "";
for (u32 i = 0; i < ref_base.size(); i++) {
result += ref_base[i];
}
Comment on lines +50 to +53
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This loop for string concatenation can be written more concisely using std::accumulate from the <numeric> header. While performance is not critical in this test, it's a good practice to use standard algorithms where applicable for clarity and conciseness.

    std::string result = std::accumulate(ref_base.begin(), ref_base.end(), std::string{});
References
  1. Refactor duplicated logic into a helper function or lambda to improve readability and maintainability. Using std::accumulate replaces a manual loop with a standard algorithm.
  2. In tests, prefer programmatic construction of expected data collections over manual, verbose initialization to improve maintainability and robustness. Using standard algorithms for test data construction improves maintainability.


// Each rank sends the reference string matching its own rank index.
std::string send = ref_base[shamcomm::world_rank()];

std::string recv = "random string"; // Just to check that it is overwritten

shamcomm::allgather_str(send, recv);

REQUIRE_EQUAL(recv, result);
}
Loading