-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnew_AstraSimNetwork.cc
More file actions
244 lines (219 loc) · 6.88 KB
/
new_AstraSimNetwork.cc
File metadata and controls
244 lines (219 loc) · 6.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
#include <json/json.hpp>
#include "astra-sim/system/AstraNetworkAPI.hh"
#include "astra-sim/system/Sys.hh"
#include "extern/remote_memory_backend/analytical/AnalyticalRemoteMemory.hh"
#include <execinfo.h>
#include <stdio.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <queue>
#include <string>
#include <thread>
#include <vector>
#include "entry.h"
#include "ns3/applications-module.h"
#include "ns3/core-module.h"
#include "ns3/csma-module.h"
#include "ns3/internet-module.h"
#include "ns3/network-module.h"
using namespace std;
using namespace ns3;
using json = nlohmann::json;
class ASTRASimNetwork : public AstraSim::AstraNetworkAPI {
public:
ASTRASimNetwork(int rank) : AstraNetworkAPI(rank) {}
~ASTRASimNetwork() {}
int sim_finish() {
for (auto it = node_to_bytes_sent_map.begin();
it != node_to_bytes_sent_map.end();
it++) {
pair<int, int> p = it->first;
if (p.second == 0) {
cout << "All data sent from node " << p.first << " is " << it->second
<< "\n";
} else {
cout << "All data received by node " << p.first << " is " << it->second
<< "\n";
}
}
return 0;
}
double sim_time_resolution() {
return 0;
}
void handleEvent(int dst, int cnt) {}
AstraSim::timespec_t sim_get_time() {
AstraSim::timespec_t timeSpec;
timeSpec.time_res = AstraSim::NS;
timeSpec.time_val = 0; // Return zero as no simulation time is needed
return timeSpec;
}
virtual void sim_schedule(
AstraSim::timespec_t delta,
void (*fun_ptr)(void* fun_arg),
void* fun_arg) {
// Remove scheduling since no simulation will occur
return;
}
virtual int sim_send(
void* buffer,
uint64_t message_size,
int type,
int dst_id,
int tag,
AstraSim::sim_request* request,
void (*msg_handler)(void* fun_arg),
void* fun_arg) {
int src_id = rank;
// Use send_flow_collective for the entire workload instead of segmented events
send_flow_collective(src_id, dst_id, message_size, msg_handler, fun_arg, tag);
return 0;
}
virtual int sim_recv(
void* buffer,
uint64_t message_size,
int type,
int src_id,
int tag,
AstraSim::sim_request* request,
void (*msg_handler)(void* fun_arg),
void* fun_arg) {
int dst_id = rank;
MsgEvent recv_event =
MsgEvent(src_id, dst_id, 1, message_size, fun_arg, msg_handler);
MsgEventKey recv_event_key =
make_pair(tag, make_pair(recv_event.src_id, recv_event.dst_id));
// Handle receive events without simulation
if (received_msg_standby_hash.find(recv_event_key) !=
received_msg_standby_hash.end()) {
int received_msg_bytes = received_msg_standby_hash[recv_event_key];
if (received_msg_bytes == message_size) {
received_msg_standby_hash.erase(recv_event_key);
recv_event.callHandler();
} else if (received_msg_bytes > message_size) {
received_msg_standby_hash[recv_event_key] =
received_msg_bytes - message_size;
recv_event.callHandler();
} else {
received_msg_standby_hash.erase(recv_event_key);
recv_event.remaining_msg_bytes -= received_msg_bytes;
sim_recv_waiting_hash[recv_event_key] = recv_event;
}
} else {
if (sim_recv_waiting_hash.find(recv_event_key) ==
sim_recv_waiting_hash.end()) {
sim_recv_waiting_hash[recv_event_key] = recv_event;
} else {
int expecting_msg_bytes =
sim_recv_waiting_hash[recv_event_key].remaining_msg_bytes;
recv_event.remaining_msg_bytes += expecting_msg_bytes;
sim_recv_waiting_hash[recv_event_key] = recv_event;
}
}
return 0;
}
};
// Global settings and their default values. Most of them are bound to
// command-line options in parse_args().

// Paths of the configuration files; empty until set.
string workload_configuration;
string system_configuration;
string network_configuration;
string memory_configuration;
string comm_group_configuration;
string logical_topology_configuration;

// Scalar options.
int num_queues_per_dim{1};
double comm_scale{1};
double injection_scale{1};
bool rendezvous_protocol{false};

// Logical topology: filled in by read_logical_topo_config(), which also
// multiplies the dimension sizes into num_npus.
vector<int> logical_dims;
int num_npus{1};

// Starts out empty; handed to every AstraSim::Sys constructed in main().
vector<int> queues_per_dim;
/**
 * Reads the logical topology from a JSON file and derives the NPU count.
 *
 * If the document has a "logical-dims" key, each of its entries (an array of
 * decimal strings) is converted to an int and appended to logical_dims. The
 * global num_npus is then set to the product of every entry in logical_dims.
 *
 * The process exits with status 1 and a message on stderr when the file
 * cannot be opened, when the JSON is malformed or has an unexpected shape,
 * or when a dimension is not a positive integer.
 *
 * @param network_configuration path of the JSON file to read.
 * @param logical_dims          receives the parsed dimension sizes (appended).
 */
void read_logical_topo_config(
    string network_configuration,
    vector<int>& logical_dims) {
  ifstream inFile(network_configuration);
  if (!inFile) {
    cerr << "Unable to open file: " << network_configuration << endl;
    exit(1);
  }

  try {
    json j;
    inFile >> j;
    if (j.contains("logical-dims")) {
      const auto logical_dims_str_vec =
          j["logical-dims"].get<vector<string>>();
      for (const auto& logical_dims_str : logical_dims_str_vec) {
        const int dim = stoi(logical_dims_str);
        if (dim <= 0) {
          // A non-positive size would make num_npus zero or negative.
          cerr << "Invalid logical dimension \"" << logical_dims_str
               << "\" in file: " << network_configuration << endl;
          exit(1);
        }
        logical_dims.push_back(dim);
      }
    }
  } catch (const std::exception& e) {
    // JSON parse/type errors and stoi failures all derive from
    // std::exception; report them instead of terminating via an uncaught
    // exception.
    cerr << "Unable to parse file: " << network_configuration << " ("
         << e.what() << ")" << endl;
    exit(1);
  }

  // Recompute from scratch so the result is the product of logical_dims even
  // if this function is called more than once.
  num_npus = 1;
  for (const int num_npus_per_dim : logical_dims) {
    num_npus *= num_npus_per_dim;
  }
}
/**
 * Binds each supported command-line option to its global variable using an
 * ns-3 CommandLine object, then parses argc/argv.
 *
 * @param argc argument count as received by main().
 * @param argv argument vector as received by main().
 */
void parse_args(int argc, char* argv[]) {
  CommandLine cmd;

  // Configuration file paths.
  cmd.AddValue("workload-configuration", "Workload configuration file.", workload_configuration);
  cmd.AddValue("system-configuration", "System configuration file", system_configuration);
  cmd.AddValue("network-configuration", "Network configuration file", network_configuration);
  cmd.AddValue("remote-memory-configuration", "Memory configuration file", memory_configuration);
  cmd.AddValue("comm-group-configuration", "Communicator group configuration file", comm_group_configuration);
  cmd.AddValue("logical-topology-configuration", "Logical topology configuration file", logical_topology_configuration);

  // Numeric and boolean options.
  cmd.AddValue("num-queues-per-dim", "Number of queues per each dimension", num_queues_per_dim);
  cmd.AddValue("comm-scale", "Communication scale", comm_scale);
  cmd.AddValue("injection-scale", "Injection scale", injection_scale);
  cmd.AddValue("rendezvous-protocol", "Whether to enable rendezvous protocol", rendezvous_protocol);

  cmd.Parse(argc, argv);
}
int main(int argc, char* argv[]) {
cout << "ASTRA-sim + NS3" << endl;
parse_args(argc, argv);
read_logical_topo_config(logical_topology_configuration, logical_dims);
vector<ASTRASimNetwork*> networks(num_npus, nullptr);
vector<AstraSim::Sys*> systems(num_npus, nullptr);
Analytical::AnalyticalRemoteMemory* mem =
new Analytical::AnalyticalRemoteMemory(memory_configuration);
for (int npu_id = 0; npu_id < num_npus; npu_id++) {
networks[npu_id] = new ASTRASimNetwork(npu_id);
systems[npu_id] = new AstraSim::Sys(
npu_id,
workload_configuration,
comm_group_configuration,
system_configuration,
mem,
networks[npu_id],
logical_dims,
queues_per_dim,
injection_scale,
comm_scale,
rendezvous_protocol);
}
if (auto ok = setup_ns3_simulation(network_configuration); ok == -1) {
std::cerr << "Fail to setup ns3 simulation." << std::endl;
return -1;
}
return 0;
}