diff --git a/buildbot/lib/buildbot.py b/buildbot/lib/buildbot.py index 49163c8a34..a5d00e4ab1 100644 --- a/buildbot/lib/buildbot.py +++ b/buildbot/lib/buildbot.py @@ -59,24 +59,36 @@ def print_buildbot_info(utility_name): print() - str_git = os.popen("git log -n 1 --decorate=full").read() + try: + r_log = subprocess.run( + ["git", "log", "-n", "1", "--decorate=full"], + capture_output=True, + text=True, + ) + if r_log.returncode != 0: + raise RuntimeError("git log failed") - git_hash = str_git.split()[1] - git_head = str_git[str_git.find("HEAD -> ") + 8 : str_git.find(")")] + str_git = r_log.stdout + git_hash = str_git.split()[1] + git_head = str_git[str_git.find("HEAD -> ") + 8 : str_git.find(")")] - git_head = git_head.split(",") + git_head = git_head.split(",") - if len(git_head) == 1: - git_head = "\033[1;92m" + git_head[0] + "\033[0;0m" - else: - git_head = "\033[1;92m" + git_head[0] + "\033[0;0m , \033[1;91m" + git_head[0] + "\033[0;0m" + if len(git_head) == 1: + git_head = "\033[1;92m" + git_head[0] + "\033[0;0m" + else: + git_head = ( + "\033[1;92m" + git_head[0] + "\033[0;0m , \033[1;91m" + git_head[1] + "\033[0;0m" + ) - print("\033[1;34mGit status \033[0;0m: ") - print(" \033[1;93mcommit \033[0;0m: ", git_hash) - print(" \033[1;36mHEAD \033[0;0m: ", git_head) - print(" \033[1;31mmodified files\033[0;0m (since last commit):") - print(os.popen('git diff-index --name-only HEAD -- | sed "s/^/ /g"').read()) - print("\033[1;90m" + "-" * col_cnt + "\033[0;0m\n") + print("\033[1;34mGit status \033[0;0m: ") + print(" \033[1;93mcommit \033[0;0m: ", git_hash) + print(" \033[1;36mHEAD \033[0;0m: ", git_head) + print(" \033[1;31mmodified files\033[0;0m (since last commit):") + print(os.popen('git diff-index --name-only HEAD -- | sed "s/^/ /g"').read()) + print("\033[1;90m" + "-" * col_cnt + "\033[0;0m\n") + except Exception: # noqa: BLE001 + print("Warn : couldn't get git status") def run_cmd(str): diff --git a/examples/benchmarks/sph_weak_scale_test.py 
b/examples/benchmarks/sph_weak_scale_test.py index 5a0bda2a99..8ec3c5504d 100644 --- a/examples/benchmarks/sph_weak_scale_test.py +++ b/examples/benchmarks/sph_weak_scale_test.py @@ -63,7 +63,6 @@ ) cfg.set_boundary_periodic() cfg.set_eos_adiabatic(gamma) - cfg.set_max_neigh_cache_size(int(100e9)) cfg.print_status() model.set_solver_config(cfg) model.init_scheduler(scheduler_split_val, scheduler_merge_val) @@ -102,7 +101,7 @@ model.set_value_in_a_box("uint", "f64", 0, bmin, bmax) - rinj = 8 * dr + rinj = 16 * dr u_inj = 1 model.add_kernel_value("uint", "f64", u_inj, (0, 0, 0), rinj) @@ -116,9 +115,6 @@ model.set_cfl_cour(0.1) model.set_cfl_force(0.1) - model.set_cfl_multipler(1e-4) - model.set_cfl_mult_stiffness(1e6) - shamrock.backends.reset_mem_info_max() # converge smoothing length and compute initial dt @@ -128,11 +124,40 @@ res_rates = [] res_cnts = [] res_system_metrics = [] + res_mpi_timers = [] + + """ + Here we insert callbacks to measure solver MPI usage by fetching the timers twice at the begining and end of the step + """ + before_mpi_timers, after_mpi_timers = None, None + + def callback_before_mpi_timer(): + global before_mpi_timers + # print(shamrock.sys.world_rank(), "register before_mpi_timers") + before_mpi_timers = shamrock.comm.get_timers() + + def callback_after_mpi_timer(): + global after_mpi_timers + # print(shamrock.sys.world_rank(), "register after_mpi_timers") + after_mpi_timers = shamrock.comm.get_timers() + + model.add_timestep_callback( + step_begin=callback_before_mpi_timer, step_end=callback_after_mpi_timer + ) + + for i in range(10): + if shamrock.sys.world_rank() == 0: + print("running step ", i + 1, "/", 10, " ...") - for i in range(5): shamrock.sys.mpi_barrier() + + # To replay the same step + model.set_next_dt(0.0) model.timestep() + if shamrock.sys.world_rank() == 0: + print("collecting results ...") + tmp_res_rate, tmp_res_cnt, tmp_system_metrics = ( model.solver_logs_last_rate(), model.solver_logs_last_obj_count(), @@ -141,6 
+166,17 @@ res_rates.append(tmp_res_rate) res_cnts.append(tmp_res_cnt) res_system_metrics.append(tmp_system_metrics) + res_mpi_timers.append(shamrock.comm.mpi_timers_delta(before_mpi_timers, after_mpi_timers)) + + if shamrock.sys.world_rank() == 0: + print("sleeping 1 second ...") + + import time + + time.sleep(1) + + if shamrock.sys.world_rank() == 0: + print("done sleeping 1 second ...") # result is the best rate of the 5 steps res_rate, res_cnt = max(res_rates), res_cnts[0] @@ -148,7 +184,7 @@ # index of the max rate max_rate_index = res_rates.index(max(res_rates)) max_rate_system_metrics = res_system_metrics[max_rate_index] - + max_mpi_timers = res_mpi_timers[max_rate_index] step_time = res_cnt / res_rate if shamrock.sys.world_rank() == 0: @@ -168,6 +204,7 @@ "rate": res_rate, "cnt": res_cnt, "step_time": step_time, + "mpi_timers": max_mpi_timers, } # print the system metrics diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index 6e6625544d..970d35d051 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -56,6 +56,7 @@ namespace shamalgs::collective { /// fetch u64_2 from global message data std::vector fetch_global_message_data( const std::vector &messages_send) { + __shamrock_stack_entry(); std::vector local_data = std::vector(messages_send.size()); @@ -84,6 +85,7 @@ namespace shamalgs::collective { /// decode message to get message std::vector decode_all_message(const std::vector &global_data) { + __shamrock_stack_entry(); std::vector message_all(global_data.size()); for (u64 i = 0; i < global_data.size(); i++) { message_all[i] = unpack(global_data[i]); @@ -94,6 +96,7 @@ namespace shamalgs::collective { /// compute message tags void compute_tags(std::vector &message_all) { + __shamrock_stack_entry(); std::vector tag_map(shamcomm::world_size(), 0); diff --git a/src/shamcomm/include/shamcomm/collectives.hpp 
b/src/shamcomm/include/shamcomm/collectives.hpp index 865ee044dc..b9d0a022ee 100644 --- a/src/shamcomm/include/shamcomm/collectives.hpp +++ b/src/shamcomm/include/shamcomm/collectives.hpp @@ -41,6 +41,19 @@ namespace shamcomm { void gather_basic_str( const std::basic_string &send_vec, std::basic_string &recv_vec); + /** + * @brief Allgathers a string from all nodes and concatenates it in a std::string + * + * This function gathers the string `send_vec` from all nodes and concatenates the + * result in `recv_vec` on every rank. The result is ordered by the order of the + * nodes in the communicator, i.e. the string is ordered by rank. + */ + void allgather_str(const std::string &send_vec, std::string &recv_vec); + + /// same as allgather_str but with std::basic_string + void allgather_basic_str( + const std::basic_string &send_vec, std::basic_string &recv_vec); + /** * @brief Constructs a histogram from a vector of strings, counting occurrences * of each unique string. @@ -56,8 +69,11 @@ namespace shamcomm { * @return An unordered map where keys are unique strings from the input and * values are the counts of their occurrences. 
(valid only on rank 0) */ - std::unordered_map string_histogram( const std::vector &inputs, std::string delimiter = "\n"); + /// same as string_histogram but with result return on every rank + std::unordered_map all_string_histogram( + const std::vector &inputs, std::string delimiter = "\n"); + } // namespace shamcomm diff --git a/src/shamcomm/include/shamcomm/wrapper.hpp b/src/shamcomm/include/shamcomm/wrapper.hpp index 501a218de3..252ee32164 100644 --- a/src/shamcomm/include/shamcomm/wrapper.hpp +++ b/src/shamcomm/include/shamcomm/wrapper.hpp @@ -19,7 +19,9 @@ #include "shambase/aliases_float.hpp" #include "shambase/aliases_int.hpp" #include "shamcomm/mpi.hpp" +#include #include +#include namespace shamcomm::mpi { @@ -29,6 +31,12 @@ namespace shamcomm::mpi { /// get a timer value f64 get_timer(std::string timername); + /// return all internal timers + const std::unordered_map &get_timers(); + + /// return all possible keys for the internal timers + const std::vector &get_possible_keys(); + /// MPI wrapper for MPI_Allreduce void Allreduce( const void *sendbuf, diff --git a/src/shamcomm/src/collectives.cpp b/src/shamcomm/src/collectives.cpp index 934c4555d3..158badcdae 100644 --- a/src/shamcomm/src/collectives.cpp +++ b/src/shamcomm/src/collectives.cpp @@ -81,6 +81,56 @@ namespace { recv_vec = result; } + /** + * @brief Allgather a vector of characters from all MPI ranks into a single string + * + * The resulting string is concatenated in rank order and is returned on every rank. + */ + template + inline void _internal_allgather_str( + const std::basic_string &send_vec, std::basic_string &recv_vec) { + StackEntry stack_loc{}; + + if (shamcomm::world_size() == 1) { + recv_vec = send_vec; + return; + } + + i32 wsize = shamcomm::world_size(); + size_t wsize_sz = static_cast(wsize); + + // counts/displacements are expressed in number of characters. + std::vector counts(wsize_sz); + std::vector disps(wsize_sz); + + // MPI counts/displacements use `int`. 
+ int local_count = static_cast(send_vec.size()); + + shamcomm::mpi::Allgather( + &local_count, 1, MPI_INT, counts.data(), 1, MPI_INT, MPI_COMM_WORLD); + + for (size_t i = 0; i < wsize_sz; i++) { + disps[i] = (i > 0) ? (disps[i - 1] + counts[i - 1]) : 0; + } + + int global_len = disps[wsize_sz - 1] + counts[wsize_sz - 1]; + + std::basic_string result; + result.resize(static_cast(global_len)); + + shamcomm::mpi::Allgatherv( + send_vec.data(), + local_count, + MPI_CHAR, + result.data(), + counts.data(), + disps.data(), + MPI_CHAR, + MPI_COMM_WORLD); + + recv_vec = result; + } + } // namespace void shamcomm::gather_str(const std::string &send_vec, std::string &recv_vec) { @@ -94,6 +144,17 @@ void shamcomm::gather_basic_str( _internal_gather_str(send_vec, recv_vec); } +void shamcomm::allgather_str(const std::string &send_vec, std::string &recv_vec) { + StackEntry stack_loc{}; + _internal_allgather_str(send_vec, recv_vec); +} + +void shamcomm::allgather_basic_str( + const std::basic_string &send_vec, std::basic_string &recv_vec) { + StackEntry stack_loc{}; + _internal_allgather_str(send_vec, recv_vec); +} + std::unordered_map shamcomm::string_histogram( const std::vector &inputs, std::string delimiter) { std::string accum_loc = ""; @@ -119,3 +180,24 @@ std::unordered_map shamcomm::string_histogram( return {}; } + +std::unordered_map shamcomm::all_string_histogram( + const std::vector &inputs, std::string delimiter) { + std::string accum_loc = ""; + for (auto &s : inputs) { + accum_loc += s + delimiter; + } + + std::string recv = ""; + allgather_str(accum_loc, recv); + + std::vector splitted = shambase::split_str(recv, delimiter); + + std::unordered_map histogram; + + for (size_t i = 0; i < splitted.size(); i++) { + histogram[splitted[i]] += 1; + } + + return histogram; +} diff --git a/src/shamcomm/src/wrapper.cpp b/src/shamcomm/src/wrapper.cpp index 3223b0fb98..7335ba6547 100644 --- a/src/shamcomm/src/wrapper.cpp +++ b/src/shamcomm/src/wrapper.cpp @@ -43,6 +43,22 @@ 
namespace shamcomm::mpi { f64 get_timer(std::string timername) { return mpi_timers[timername]; } + const std::unordered_map &get_timers() { return mpi_timers; } + + std::vector possible_keys{ + "total", "MPI_Isend", "MPI_Irecv", + "MPI_Allreduce", "MPI_Allgather", "MPI_Allgatherv", + "MPI_Exscan", "MPI_Wait", "MPI_Waitall", + "MPI_Barrier", "MPI_Probe", "MPI_Recv", + "MPI_Get_count", "MPI_Send", "MPI_File_set_view", + "MPI_Type_size", "MPI_File_write_all", "MPI_File_write", + "MPI_File_read", "MPI_File_write_at", "MPI_File_read_at", + "MPI_File_close", "MPI_File_open", "MPI_Test", + "MPI_Gather", "MPI_Gatherv", + }; + + const std::vector &get_possible_keys() { return possible_keys; } + } // namespace shamcomm::mpi namespace { diff --git a/src/shammodels/gsph/src/modules/GSPHGhostHandler.cpp b/src/shammodels/gsph/src/modules/GSPHGhostHandler.cpp index 070fc6f858..531a6e7508 100644 --- a/src/shammodels/gsph/src/modules/GSPHGhostHandler.cpp +++ b/src/shammodels/gsph/src/modules/GSPHGhostHandler.cpp @@ -344,13 +344,13 @@ auto GSPHGhostHandler::gen_id_table_interfaces(GeneratorMap &&gen) for (auto &[k, v] : send_count_stats) { if (v > 0.2) { - warn_log += shambase::format("\n patch {} high interf/patch volume: {}", k, v); + // warn_log += shambase::format("\n patch {} high interf/patch volume: {}", k, v); has_warn = true; } } if (has_warn && shamcomm::world_rank() == 0) { - warn_log = "\n This can lead to high mpi " + warn_log = "\n High interf/patch volume. 
This can lead to high mpi " "overhead, try to increase the patch split crit" + warn_log; } diff --git a/src/shammodels/sph/include/shammodels/sph/SPHUtilities.hpp b/src/shammodels/sph/include/shammodels/sph/SPHUtilities.hpp index af5f17a9bd..80dacb3af9 100644 --- a/src/shammodels/sph/include/shammodels/sph/SPHUtilities.hpp +++ b/src/shammodels/sph/include/shammodels/sph/SPHUtilities.hpp @@ -86,7 +86,14 @@ namespace shammodels::sph { PatchField interactR_patch = sched.map_owned_to_patch_field_simple( [&](const Patch p, PatchDataLayer &pdat) -> flt { if (!pdat.is_empty()) { +#if false + auto tmp = pdat.get_field(ihpart).compute_max() * h_evol_max * Rkern; + shamcomm::logs::raw_ln( + shambase::format("patch {}, Rghost = {}", p.id_patch, tmp)); + return tmp; +#else return pdat.get_field(ihpart).compute_max() * h_evol_max * Rkern; +#endif } else { return shambase::VectorProperties::get_min(); } diff --git a/src/shammodels/sph/include/shammodels/sph/Solver.hpp b/src/shammodels/sph/include/shammodels/sph/Solver.hpp index 10df216c7b..83509cbcd1 100644 --- a/src/shammodels/sph/include/shammodels/sph/Solver.hpp +++ b/src/shammodels/sph/include/shammodels/sph/Solver.hpp @@ -30,9 +30,12 @@ #include "shamrock/scheduler/ShamrockCtx.hpp" #include "shamsys/legacy/log.hpp" #include "shamtree/TreeTraversalCache.hpp" +#include #include +#include #include #include +#include namespace shammodels::sph { struct TimestepLog { @@ -75,6 +78,12 @@ namespace shammodels::sph { Config solver_config; SolverLog solve_logs; + struct SolverStepCallback { + std::optional> step_begin_callback; + std::optional> step_end_callback; + }; + std::vector timestep_callbacks{}; + inline void init_required_fields() { solver_config.set_layout(context.get_pdl_write()); } // serial patch tree control diff --git a/src/shammodels/sph/src/BasicSPHGhosts.cpp b/src/shammodels/sph/src/BasicSPHGhosts.cpp index 6b6357761c..8f5e1a164f 100644 --- a/src/shammodels/sph/src/BasicSPHGhosts.cpp +++ 
b/src/shammodels/sph/src/BasicSPHGhosts.cpp @@ -560,13 +560,13 @@ auto BasicSPHGhostHandler::gen_id_table_interfaces(GeneratorMap &&gen) for (auto &[k, v] : send_count_stats) { if (v > 0.2) { - warn_log += shambase::format("\n patch {} high interf/patch volume: {}", k, v); + // warn_log += shambase::format("\n patch {} high interf/patch volume: {}", k, v); has_warn = true; } } if (has_warn && shamcomm::world_rank() == 0) { - warn_log = "\n This can lead to high mpi " + warn_log = "\n High interf/patch volume. This can lead to high mpi " "overhead, try to increase the patch split crit" + warn_log; } diff --git a/src/shammodels/sph/src/Solver.cpp b/src/shammodels/sph/src/Solver.cpp index 0ef3a5ad05..66e3001fc8 100644 --- a/src/shammodels/sph/src/Solver.cpp +++ b/src/shammodels/sph/src/Solver.cpp @@ -1583,6 +1583,12 @@ shammodels::sph::TimestepLog shammodels::sph::Solver::evolve_once() sham::MemPerfInfos mem_perf_infos_start = sham::details::get_mem_perf_info(); f64 mpi_timer_start = shamcomm::mpi::get_timer("total"); + for (auto &callbacks : timestep_callbacks) { + if (callbacks.step_begin_callback) { + shambase::get_check_ref(callbacks.step_begin_callback)(); + } + } + Tscal t_current = solver_config.get_time(); Tscal dt = solver_config.get_dt_sph(); @@ -2643,6 +2649,12 @@ shammodels::sph::TimestepLog shammodels::sph::Solver::evolve_once() tstep.end(); + for (auto it = timestep_callbacks.rbegin(); it != timestep_callbacks.rend(); ++it) { + if (it->step_end_callback) { + shambase::get_check_ref(it->step_end_callback)(); + } + } + f64 delta_mpi_timer = shamcomm::mpi::get_timer("total") - mpi_timer_start; sham::MemPerfInfos mem_perf_infos_end = sham::details::get_mem_perf_info(); diff --git a/src/shammodels/sph/src/pySPHModel.cpp b/src/shammodels/sph/src/pySPHModel.cpp index 74da727ebb..b927346f79 100644 --- a/src/shammodels/sph/src/pySPHModel.cpp +++ b/src/shammodels/sph/src/pySPHModel.cpp @@ -38,7 +38,9 @@ #include #include #include +#include #include +#include 
template class SPHKernel> void add_instance(py::module &m, std::string name_config, std::string name_model) { @@ -1246,7 +1248,18 @@ void add_instance(py::module &m, std::string name_config, std::string name_model return sched.get_patch_transform(); }) .def("apply_momentum_offset", &T::apply_momentum_offset) - .def("apply_position_offset", &T::apply_position_offset); + .def("apply_position_offset", &T::apply_position_offset) + .def( + "add_timestep_callback", + [](T &self, + std::optional> step_begin_callback, + std::optional> step_end_callback) { + self.solver.timestep_callbacks.push_back( + {std::move(step_begin_callback), std::move(step_end_callback)}); + }, + py::kw_only(), + py::arg("step_begin") = std::nullopt, + py::arg("step_end") = std::nullopt); } template class SPHKernel> diff --git a/src/shampylib/src/pyShamcomm.cpp b/src/shampylib/src/pyShamcomm.cpp new file mode 100644 index 0000000000..c50985332f --- /dev/null +++ b/src/shampylib/src/pyShamcomm.cpp @@ -0,0 +1,51 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2026 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +/** + * @file pyShamcomm.cpp + * @author Timothée David--Cléris (tim.shamrock@proton.me) + * @brief + */ + +#include "shamalgs/collective/reduction.hpp" +#include "shambindings/pybind11_stl.hpp" +#include "shambindings/pybindaliases.hpp" +#include "shambindings/pytypealias.hpp" +#include "shamcomm/collectives.hpp" +#include "shamcomm/logs.hpp" +#include "shamcomm/wrapper.hpp" +#include +#include +#include +#include + +Register_pymod(shamcommlibinit) { + + py::module shamcomm_module = m.def_submodule("comm", "comm library"); + + shamcomm_module.def("get_timer", [](std::string name) { + return 
shamcomm::mpi::get_timer(std::move(name)); + }); + + shamcomm_module.def("get_timers", []() { + return shamcomm::mpi::get_timers(); + }); + + shamcomm_module.def( + "mpi_timers_delta", + [](std::unordered_map start, std::unordered_map end) { + std::unordered_map deltas{}; + + for (auto &k : shamcomm::mpi::get_possible_keys()) { + deltas[k] = shamalgs::collective::allreduce_max(end[k] - start[k]); + } + + return deltas; + }); +} diff --git a/src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp b/src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp index 8449755de1..b2decdfb4b 100644 --- a/src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp +++ b/src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp @@ -269,10 +269,13 @@ class SerialPatchTree { sycl::queue &queue, shamrock::patch::PatchField pfield, Func &&reducer) { + __shamrock_stack_entry(); + shamrock::patch::PatchtreeField ptfield; ptfield.allocate(get_element_count()); { + __shamrock_stack_entry(); sycl::host_accessor lpid{ shambase::get_check_ref(linked_patch_ids_buf), sycl::read_only}; sycl::host_accessor tree_field{ @@ -280,6 +283,8 @@ class SerialPatchTree { // init reduction std::unordered_map &idp_to_gid = sched.patch_list.id_patch_to_global_idx; + +#pragma omp parallel for for (u64 idx = 0; idx < get_element_count(); idx++) { tree_field[idx] = (lpid[idx] != u64_max) ? 
pfield.get(lpid[idx]) : T(); } diff --git a/src/shamrock/include/shamrock/scheduler/loadbalance/LoadBalanceStrategy.hpp b/src/shamrock/include/shamrock/scheduler/loadbalance/LoadBalanceStrategy.hpp index c429d06f4e..db6347a4ba 100644 --- a/src/shamrock/include/shamrock/scheduler/loadbalance/LoadBalanceStrategy.hpp +++ b/src/shamrock/include/shamrock/scheduler/loadbalance/LoadBalanceStrategy.hpp @@ -222,7 +222,8 @@ namespace shamrock::scheduler::details { inline LBMetric compute_LB_metric( const std::vector> &lb_vector, const std::vector &new_owners, - i32 world_size) { + i32 world_size, + f64 strategy_weight) { std::vector load_per_node(world_size, 0); @@ -250,7 +251,11 @@ namespace shamrock::scheduler::details { } var /= world_size; - return {min, max, avg, sycl::sqrt(var)}; + return { + min * strategy_weight, + max * strategy_weight, + avg * strategy_weight, + sycl::sqrt(var) * strategy_weight}; } } // namespace shamrock::scheduler::details @@ -270,30 +275,39 @@ namespace shamrock::scheduler { std::vector> &&lb_vector, i32 world_size = shamcomm::world_size()) { - auto tmpres = details::lb_startegy_parallel_sweep(lb_vector, world_size); - auto metric_psweep = details::compute_LB_metric(lb_vector, tmpres, world_size); + using namespace details; - auto tmpres_2 = details::lb_startegy_roundrobin(lb_vector, world_size); - auto metric_rrobin = details::compute_LB_metric(lb_vector, tmpres_2, world_size); + f64 factor_boost_psweep = 1; + auto tmpres = lb_startegy_parallel_sweep(lb_vector, world_size); + auto metric_psweep = compute_LB_metric(lb_vector, tmpres, world_size, factor_boost_psweep); + // We boost the round robin strategy to favor it if the difference is around 5% since the + // increased uniformity will probably offset the cost anyway + f64 factor_boost_rrobin = 0.95; + auto tmpres_2 = lb_startegy_roundrobin(lb_vector, world_size); + auto metric_rrobin + = compute_LB_metric(lb_vector, tmpres_2, world_size, factor_boost_rrobin); + + std::string strategy_name = 
"parallel sweep"; if (metric_rrobin.max < metric_psweep.max) { - tmpres = tmpres_2; + tmpres = tmpres_2; + strategy_name = "round robin"; } if (shamcomm::world_rank() == 0) { - logger::info_ln("LoadBalance", "summary :"); - logger::info_ln( - "LoadBalance", - " - strategy \"psweep\" : max =", - metric_psweep.max, - "min =", - metric_psweep.min); logger::info_ln( "LoadBalance", - " - strategy \"round robin\" : max =", - metric_rrobin.max, - "min =", - metric_rrobin.min); + shambase::format( + R"=(Summary (strategy = {0:}): + - strategy "psweep" : max = {1:.1f} min = {2:.1f} factor = {3:} + - strategy "round robin" : max = {4:.1f} min = {5:.1f} factor = {6:})=", + strategy_name, + metric_psweep.max, + metric_psweep.min, + factor_boost_psweep, + metric_rrobin.max, + metric_rrobin.min, + factor_boost_rrobin)); } return tmpres; } diff --git a/src/shamrock/src/solvergraph/ExchangeGhostLayer.cpp b/src/shamrock/src/solvergraph/ExchangeGhostLayer.cpp index c87577db9a..5edb418156 100644 --- a/src/shamrock/src/solvergraph/ExchangeGhostLayer.cpp +++ b/src/shamrock/src/solvergraph/ExchangeGhostLayer.cpp @@ -30,6 +30,28 @@ void shamrock::solvergraph::ExchangeGhostLayer::_impl_evaluate_internal() { auto &ghost_layer = edges.ghost_layer; const shamrock::solvergraph::RankGetter &rank_owner = edges.rank_owner; +#if false + std::unordered_map msg_sizes_send; + std::unordered_map msg_sizes_max_send; + + std::stringstream ss; + ss << "Rank " << shamcomm::world_rank() << " is sending " + << ghost_layer.patchdatas.get_native().size() << " patches sizes:"; + for (auto &pdat : ghost_layer.patchdatas.get_native()) { + u64 key = rank_owner.get_rank_owner(pdat.first.first); + // ss << pdat.first.first << " " << pdat.first.second << " " << pdat.second.get_obj_cnt() << + // "\n"; + msg_sizes_send[key] += pdat.second.get_obj_cnt(); + msg_sizes_max_send[key] = std::max(msg_sizes_max_send[key], u64(pdat.second.get_obj_cnt())); + } + for (auto &[rank, size] : msg_sizes_send) { + ss << "\n" + << 
"msg size from rank " << rank << " is " << size << " max is " + << msg_sizes_max_send[rank]; + } + shamcomm::logs::raw_ln(ss.str()); +#endif + shambase::DistributedDataShared recv_dat; shamalgs::collective::serialize_sparse_comm( diff --git a/src/tests/shamcomm/collectivesTests.cpp b/src/tests/shamcomm/collectivesTests.cpp index a4d0b732a5..4be7c99d4f 100644 --- a/src/tests/shamcomm/collectivesTests.cpp +++ b/src/tests/shamcomm/collectivesTests.cpp @@ -37,3 +37,26 @@ TestStart(Unittest, "shamcomm/collectives::gather_str", test_gather_str, 4) { REQUIRE_EQUAL(recv, result); } + +TestStart(Unittest, "shamcomm/collectives::allgather_str", test_allgather_str, 4) { + + std::array ref_base{ + "I'm a very important string", + "But I'm a very important string", + "Listen, I'm a very important string", + "The most importantest string", + }; + + std::string result = ""; + for (u32 i = 0; i < ref_base.size(); i++) { + result += ref_base[i]; + } + + std::string send = ref_base[shamcomm::world_rank()]; + + std::string recv = "random string"; // Just to check that it is overwritten + + shamcomm::allgather_str(send, recv); + + REQUIRE_EQUAL(recv, result); +}