Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ set(PCMS_HEADERS
pcms/field_communicator.h
pcms/field_communicator2.h
pcms/field_evaluation_methods.h
pcms/global_communicator.h
pcms/memory_spaces.h
pcms/types.h
pcms/array_mask.h
Expand Down Expand Up @@ -51,11 +52,13 @@ if(PCMS_ENABLE_OMEGA_H)
list(
APPEND
PCMS_HEADERS
pcms/adapter/omega_h/omega_h_field.h
#pcms/adapter/omega_h/omega_h_field.h
pcms/transfer_field.h
pcms/transfer_field2.h
pcms/uniform_grid.h
pcms/point_search.h)
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/pcms/adapter/omega_h/omega_h_field.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/pcms/adapter/omega_h)
endif()

find_package(Kokkos REQUIRED)
Expand Down
32 changes: 32 additions & 0 deletions src/pcms/coupler.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#ifndef PCMS_COUPLER_H
#define PCMS_COUPLER_H
#include "global_communicator.h"
#include "pcms/common.h"
#include "pcms/field_communicator.h"
#include "pcms/adapter/omega_h/omega_h_field.h"
Expand Down Expand Up @@ -112,7 +113,32 @@ class CoupledField
private:
std::unique_ptr<CoupledFieldConcept> coupled_field_;
};
template <typename T>
class GlobalDataInterface
{
public:
GlobalDataInterface( const std::string& name , MPI_Comm mpi_comm, redev::Channel& channel)
: mpi_comm_(mpi_comm), comm_(GlobalCommunicator<T>(name, mpi_comm_, channel)),
type_info_(typeid(T))
{
PCMS_FUNCTION_TIMER;
}
void Send(T* msg, std::string VarName, size_t msg_size, Mode mode = Mode::Synchronous)
{
PCMS_FUNCTION_TIMER;
comm_.Send(msg, VarName, msg_size, mode);
}
std::vector<T> Receive(std::string VarName, size_t msg_size, Mode mode = Mode::Synchronous)
{
PCMS_FUNCTION_TIMER;
return comm_.Receive(VarName, msg_size, mode);
}
private:
MPI_Comm mpi_comm_;
const std::type_info& type_info_;
GlobalCommunicator<T> comm_;

};
class Application
{
public:
Expand Down Expand Up @@ -142,6 +168,12 @@ class Application
}
return &(it->second);
}
template <typename T>
std::unique_ptr<GlobalDataInterface<T>> Add_GDI(std::string name, MPI_Comm mpi_comm)
{
PCMS_FUNCTION_TIMER;
return std::make_unique<GlobalDataInterface<T>>(name, mpi_comm, channel_); // Use the existing applivatiocation channel
}
void SendField(const std::string& name, Mode mode = Mode::Synchronous)
{
PCMS_FUNCTION_TIMER;
Expand Down
49 changes: 49 additions & 0 deletions src/pcms/global_communicator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef PCMS_GLOBAL_COMMUNICATOR_H
#define PCMS_GLOBAL_COMMUNICATOR_H
#endif // PCMS_GLOBAL_COMMUNICATOR_H

#include <redev.h>

namespace pcms
{
using redev::Mode;
template <typename T>
struct GlobalCommunicator
{
using value_type = T;
public:
GlobalCommunicator(std::string name, MPI_Comm mpi_comm, redev::Channel& channel)
: mpi_comm(mpi_comm),
channel_(channel),
name_(std::move(name))
{
PCMS_FUNCTION_TIMER;
comm_ = channel_.CreateComm<T>(name_, mpi_comm, redev::CommType::Global );
}
GlobalCommunicator(const GlobalCommunicator&) = delete;
GlobalCommunicator& operator=(const GlobalCommunicator&) = delete;
GlobalCommunicator(GlobalCommunicator&&)= default;
GlobalCommunicator& operator=(GlobalCommunicator&&) = default;

void Send(T* msg, std::string VarName, size_t msg_size, Mode mode = Mode::Synchronous)
{
PCMS_FUNCTION_TIMER;
PCMS_ALWAYS_ASSERT(channel_.InSendCommunicationPhase());
comm_.SetCommParams( VarName, msg_size);
comm_.Send(msg, mode);
}
std::vector<T> Receive(std::string VarName, size_t msg_size, Mode mode = Mode::Synchronous)
{
PCMS_FUNCTION_TIMER;
PCMS_ALWAYS_ASSERT(channel_.InReceiveCommunicationPhase());
comm_.SetCommParams(VarName, msg_size);
auto data = comm_.Recv(mode);
return data;
}
private:
MPI_Comm mpi_comm;
redev::Channel& channel_;
std::string name_;
redev::BidirectionalComm<T> comm_;
};
}
32 changes: 31 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,37 @@ if(PCMS_ENABLE_OMEGA_H)
${d3d16p}
ignored)
endif()

add_exe(test_GDI)
tri_mpi_test(
TESTNAME
test_GDI
TIMEOUT
20
NAME1
app
EXE1
./test_GDI
PROCS1
1
ARGS1
1
NAME2
rdv
EXE2
./test_GDI
PROCS2
1
ARGS2
-1
NAME3
app
EXE3
./test_GDI
PROCS3
1
ARGS3
0
)
set(d3d8p ${PCMS_TEST_DATA_DIR}/d3d/d3d-full_9k_sfc_p8.osh/)
add_exe(test_twoClientOverlap)
if(HOST_NPROC GREATER_EQUAL 28)
Expand Down
130 changes: 130 additions & 0 deletions test/test_GDI.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include <Omega_h_mesh.hpp>
#include <iostream>
#include <pcms.h>
#include <pcms/types.h>
#include <Omega_h_file.hpp>
#include "test_support.h"
#include "pcms/adapter/omega_h/omega_h_field.h"

static constexpr bool done = true;
static constexpr int COMM_ROUNDS = 1;

void xgc_delta_f(MPI_Comm comm)
{
pcms::Coupler coupler("proxy_couple", comm, false, {});
pcms::Application* app = coupler.AddApplication("proxy_couple_xgc_delta_f");

const auto GDI = app->Add_GDI<pcms::GO>("global_comm", comm);
auto mean = std::vector<pcms::GO>(1);
mean[0] = 16;
do {
for (int i = 0; i < COMM_ROUNDS; ++i) {
app->BeginSendPhase();
GDI->Send(mean.data(), "mean", mean.size());
app->EndSendPhase();
printf("delta Sent mean:%d\n", mean[0]);
app->BeginReceivePhase();
mean = GDI->Receive("mean", mean.size());
app->EndReceivePhase();
mean[0] = mean[0]/2;
}
} while (!done);
printf("final Mean = %d\n", mean[0]);
assert(std::fabs(mean[0] - 1.0) < 1e-12);
printf("GDI test successful.\n");
}
void xgc_total_f(MPI_Comm comm)
{
pcms::Coupler coupler("proxy_couple", comm, false, {});
pcms::Application* app = coupler.AddApplication("proxy_couple_xgc_total_f");

auto GDI = app->Add_GDI<pcms::GO>("global_comm", comm);
auto mean = std::vector<pcms::GO>(1);
do {
for (int i = 0; i < COMM_ROUNDS; ++i) {
app->BeginReceivePhase();
mean = GDI->Receive("mean", mean.size());
app->EndReceivePhase();
printf("total Recieved mean:%d\n", mean[0]);
mean[0] = mean[0]/2;
app->BeginSendPhase();
GDI->Send(mean.data(), "mean", mean.size());
app->EndSendPhase();
printf("total Sent mean:%d\n", mean[0]);
}
} while (!done);
}
void xgc_coupler(MPI_Comm comm)
{
// Define Partition
redev::LO dim = 3;
redev::LOs ranks(1);
std::iota(ranks.begin(), ranks.end(), 0);
redev::Reals cuts = {0};
auto partition = redev::Partition{redev::RCBPtn{dim, ranks, cuts}};

pcms::Coupler cpl("proxy_couple", comm, true,
partition);
auto* total_f = cpl.AddApplication("proxy_couple_xgc_total_f");
auto* delta_f = cpl.AddApplication("proxy_couple_xgc_delta_f");

auto GDI_total = total_f->Add_GDI<pcms::GO>("global_comm", comm);
auto GDI_delta = delta_f->Add_GDI<pcms::GO>("global_comm", comm);
auto mean = std::vector<pcms::GO>(1);
do {
for (int i = 0; i < COMM_ROUNDS; ++i) {
delta_f->BeginReceivePhase();
mean = GDI_delta->Receive("mean", 1);
delta_f->EndReceivePhase();
printf("delta Received mean:%d\n", mean[0]);
mean[0] = mean[0]/2;
const auto msg_size = mean.size();
total_f->BeginSendPhase();
GDI_total->Send(mean.data(), "mean", msg_size);
total_f->EndSendPhase();
printf("total sent mean:%d\n", mean[0]);
total_f->BeginReceivePhase();
mean = GDI_total->Receive("mean", msg_size);
total_f->EndReceivePhase();
printf("delta Received mean:%d\n", mean[0]);
mean[0] = mean[0]/2;
delta_f->BeginSendPhase();
GDI_delta->Send(mean.data(), "mean", msg_size);
delta_f->EndSendPhase();
printf("detla sent mean:%d\n", mean[0]);
}
} while (!done);
}

int main(int argc, char** argv)
{
MPI_Init(&argc, &argv); // MPI init

OMEGA_H_CHECK(argc == 2);
const auto clientId = atoi(argv[1]);
REDEV_ALWAYS_ASSERT(clientId >= -1 && clientId <= 1);

int color;
if (clientId == -1)
color = 0; // coupler
else if (clientId == 0)
color = 1; // client A
else if (clientId == 1)
color = 2; // client B
else
color = MPI_UNDEFINED;

MPI_Comm subcomm;
MPI_Comm_split(MPI_COMM_WORLD, color, 0, &subcomm);

switch (clientId) {
case -1: xgc_coupler(subcomm); break;
case 0: xgc_delta_f(subcomm); break;
case 1: xgc_total_f(subcomm); break;
default:
std::cerr << "Unhandled client id (should be -1, 0,1)\n";
exit(EXIT_FAILURE);
}
MPI_Finalize();
return 0;
}