-
Notifications
You must be signed in to change notification settings - Fork 6
/
receive.cpp
202 lines (171 loc) · 6.43 KB
/
receive.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
#include <mpi.h>
#include <vector>
#include <memory>
#include <henson/data.h>
#include <henson/context.h>
#include <fmt/format.h>
#include <fmt/ostream.h>
#include <opts/opts.h>
#include "common.hpp"
// UNUSED(expr): evaluates `expr` and discards the result through a cast to
// void. The do { ... } while (0) wrapper makes the expansion a single
// statement that must be terminated with a semicolon at the call site.
// Intended for marking a variable as deliberately unused.
// Technique taken from: http://stackoverflow.com/a/1486931/44738
#define UNUSED(expr) do { (void)(expr); } while (0)
// READ_TYPE(VARTYPE): expands to a bare `if` statement (no `else` branch), so
// call sites can chain several expansions with `else`.
//
// When var.type equals the stringified VARTYPE (e.g. "int"), the expansion
// declares a local VARTYPE, fills it with read(buffer, position, x), and
// passes it to henson_save_<VARTYPE>() under the name var.name.
//
// The macro is not hygienic: it requires `var`, `buffer`, and `position` to be
// in scope under exactly those names where it is expanded.
// NOTE(review): read() is declared elsewhere; presumably it advances
// `position` past the value it consumed -- confirm against common.hpp.
#define READ_TYPE(VARTYPE) \
if (var.type == #VARTYPE) \
{ \
VARTYPE x; \
read(buffer, position, x); \
henson_save_##VARTYPE(var.name.c_str(), x); \
}
/**
 * Receiving end of a henson data exchange.
 *
 * Usage: <prog> [-a|--async] [-h|--help] REMOTE_GROUP [variables]+
 *
 * Each iteration of the main loop:
 *   1. (async mode, local rank 0 only) sends an empty request_data message to
 *      the remote group;
 *   2. local rank 0 checks for a stop message from remote rank 0 and
 *      broadcasts the result to the local group; on stop, every rank returns 0;
 *   3. receives one data message from every partner rank in the remote group;
 *   4. unpacks the listed variables from each message and stores them with the
 *      henson_save_* functions (arrays coming from several partners are
 *      concatenated first);
 *   5. yields to henson.
 *
 * Returns 1 on a usage error, when henson is not active, or when the local
 * and remote group sizes are not multiples of one another.
 */
int main(int argc, char** argv)
{
    using namespace opts;
    Options ops;

    // Give the flags defined values up front instead of relying on the option
    // parser to write to them when the flag is absent from the command line.
    bool async = false;
    bool help  = false;
    ops
        >> Option('a', "async", async, "asynchronous mode")
        >> Option('h', "help",  help,  "show help");

    std::string remote_group;
    if (!ops.parse(argc,argv) || help || !(ops >> PosOption(remote_group)))
    {
        fmt::print("Usage: {} REMOTE_GROUP [variables]+\n{}", argv[0], ops);
        return 1;
    }

    // Every remaining positional argument describes one variable to receive.
    std::vector<Variable> variables;
    std::string var;
    while (ops >> PosOption(var))
        variables.push_back(parse_variable(var));

    if (!henson_active())
    {
        fmt::print("Must run under henson, but henson is not active\n");
        return 1;
    }

    // Setup communicators
    MPI_Comm local = henson_get_world();
    int rank, size;
    MPI_Comm_rank(local, &rank);
    MPI_Comm_size(local, &size);

    MPI_Comm remote = henson_get_intercomm(remote_group.c_str());
    int remote_size;
    MPI_Comm_remote_size(remote, &remote_size);

    // Figure out partner ranks: with more local than remote ranks, several
    // local ranks share one remote partner; with fewer, each local rank takes
    // a contiguous run of remote ranks.
    std::vector<int> ranks;
    if (size >= remote_size)
    {
        if (size % remote_size != 0)
        {
            if (rank == 0)
                fmt::print("[receive]: group size must be divisible by remote size (or vice versa), got {} vs {}\n", size, remote_size);
            return 1;
        }
        ranks.push_back(rank / (size / remote_size));
    } else
    {
        if (remote_size % size != 0)
        {
            if (rank == 0)
                fmt::print("[receive]: remote size must be divisible by the group size (or vice versa), got {} vs {}\n", size, remote_size);
            return 1;
        }
        const int fraction = remote_size / size;
        ranks.reserve(fraction);
        for (int i = 0; i < fraction; ++i)
            ranks.push_back(rank*fraction + i);
    }
    const bool single_partner = (ranks.size() == 1);

    // An "array" variable may name several comma-separated arrays; count them
    // all so the per-array accumulators can be sized once per iteration.
    size_t array_count = 0;
    for (const Variable& var : variables)
        if (var.type == "array")
            array_count += split(var.name, ',').size();

    while (true)
    {
        MPI_Status s;

        // request more data
        if (async && rank == 0)
            MPI_Send(nullptr, 0, MPI_INT, rank, tags::request_data, remote);

        // check if we are told to stop
        // TODO: this loop forces us to wait until there is a message waiting from every rank
        //       we communicate with; in general, this is not great
        int stop = 0;
        if (rank == 0)
        {
            // Block until anything arrives from remote rank 0, then test
            // whether a stop message is among the pending ones.
            MPI_Probe(rank, MPI_ANY_TAG, remote, &s);
            MPI_Iprobe(rank, tags::stop, remote, &stop, &s);
            if (stop)
            {
                fmt::print("[{}]: stop signal in receive\n", rank);
                MPI_Recv(nullptr, 0, MPI_INT, rank, tags::stop, remote, &s);    // unblock the send
            }
        }
        MPI_Bcast(&stop, 1, MPI_INT, 0, local);
        if (stop)
            return 0;

        // Receive one message per partner. The buffers stay alive until the
        // end of the iteration because, with a single partner, the saved
        // arrays point straight into them.
        std::vector<std::vector<char>> buffers(ranks.size());
        for (size_t i = 0; i < ranks.size(); ++i)
        {
            const int partner = ranks[i];
            auto& buffer = buffers[i];

            int c = 0;
            MPI_Probe(partner, tags::data, remote, &s);
            MPI_Get_count(&s, MPI_BYTE, &c);
            buffer.resize(c);
            // data() is valid even for an empty buffer, unlike &buffer[0].
            MPI_Recv(buffer.data(), c, MPI_BYTE, partner, tags::data, remote, &s);
        }

        std::vector<std::vector<char>> arrays(array_count);
        std::vector<std::tuple<size_t, size_t>> arrays_meta(array_count, std::tuple<size_t, size_t>(0, 0));     // (count, type)
        for (size_t i = 0; i < ranks.size(); ++i)
        {
            // NOTE: READ_TYPE expects the names `var`, `buffer` and `position`.
            auto& buffer = buffers[i];
            size_t position = 0;
            size_t array_idx = 0;
            for (const Variable& var : variables)
            {
                READ_TYPE(int)      else
                READ_TYPE(size_t)   else
                READ_TYPE(float)    else
                READ_TYPE(double)   else
                if (var.type == "array")
                {
                    for (const auto& name : split(var.name, ','))
                    {
                        // Each array is laid out as: count, type, then
                        // count*type bytes of payload.
                        size_t count = 0;
                        size_t type  = 0;
                        read(buffer, position, count);
                        read(buffer, position, type);
                        const size_t bytes = count*type;

                        // Pointer arithmetic on data() stays well-defined when
                        // the payload is empty and position == buffer.size().
                        // NOTE(review): the payload length is trusted as sent;
                        // nothing here verifies position + bytes <= buffer.size().
                        char* data = buffer.data() + position;

                        if (single_partner)
                            henson_save_array(name.c_str(), data, type, count, type);     // save directly
                        else
                        {
                            // append this partner's piece to the combined array
                            auto& array = arrays[array_idx];
                            array.insert(array.end(), data, data + bytes);
                            std::get<0>(arrays_meta[array_idx]) += count;
                            std::get<1>(arrays_meta[array_idx])  = type;
                        }
                        position += bytes;
                        ++array_idx;
                    }
                } else
                    fmt::print("Warning: unknown type {} for {}\n", var.type, var.name);
            }

            if (!single_partner)
                std::vector<char>().swap(buffers[i]);       // wipe out the buffer that we no longer need
        }

        if (!single_partner)
        {
            // save the arrays
            size_t array_idx = 0;
            for (const Variable& var : variables)
                if (var.type == "array")
                    for (const auto& name : split(var.name, ','))
                    {
                        size_t count; size_t type;
                        std::tie(count,type) = arrays_meta[array_idx];
                        henson_save_array(name.c_str(), arrays[array_idx].data(), type, count, type);
                        ++array_idx;
                    }
        }

        henson_yield();
    }
}