-
Notifications
You must be signed in to change notification settings - Fork 2
/
ggml-mpi.h
305 lines (250 loc) · 9.66 KB
/
ggml-mpi.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
// ggml-mpi.h — public interface of the GGML MPI backend: context
// management, node-to-node synchronization helpers, and compute-graph
// hooks for distributing model layers across MPI ranks.
#pragma once
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
// Opaque GGML types: this header only uses pointers to them, so forward
// declarations suffice and no ggml.h include is needed here.
struct ggml_context;
struct ggml_tensor;
struct ggml_cgraph;
#ifdef __cplusplus
extern "C" {
#endif
// MPI message tags. These values are passed as the `tag` argument of the
// probe/sync functions declared below to label what kind of message is
// being exchanged: operation commands (DECODE, KV_* ops, SHUTDOWN,
// CANCEL_RUN, BEGIN_TRANSACTION) and batch payload fields (TOKENS, POS,
// SEQ_IDS, ...). NOTE(review): exact producer/consumer pairing of each
// tag lives in ggml-mpi.c — confirm there before reusing a tag.
#define GGML_MPI_DECODE 0
#define GGML_MPI_KV_CLEAR 1
#define GGML_MPI_KV_SEQ_RM 2
#define GGML_MPI_KV_SEQ_CP 3
#define GGML_MPI_KV_SEQ_KEEP 4
#define GGML_MPI_KV_SEQ_SHIFT 5
#define GGML_MPI_SHUTDOWN 6
#define GGML_MPI_TRANSFER_TENSORS 7
#define GGML_MPI_SYNC_LOGITS 8
#define GGML_MPI_CANCEL_RUN 9
#define GGML_MPI_KV_SEQ_CP_BACK 10
#define GGML_MPI_TRANS_ID 11
#define GGML_MPI_BATCH_ID 12
#define GGML_MPI_N_TOKENS 13
#define GGML_MPI_TOKENS 14
#define GGML_MPI_N_SEQ_IDS 15
#define GGML_MPI_SEQ_IDS 16
#define GGML_MPI_POS 17
#define GGML_MPI_BEGIN_TRANSACTION 18
#define GGML_MPI_MAX_N_SEQ 19
#define GGML_MPI_BATCH_LOGITS 20
/**
 * The context used for MPI operations,
 * a program may make use of more than one
 * context but must always have at least one.
 *
 * The context stores required information like the
 * node rank and a communicator to use for MPI operations.
 * A context is guaranteed to be internally consistent,
 * meaning that a context's stored rank is valid within
 * the context's communicator.
 */
struct ggml_mpi_context;
// Returns the transaction id currently stored in the context.
// NOTE(review): "trans id" appears to identify a decode transaction
// (cf. GGML_MPI_TRANS_ID / GGML_MPI_BEGIN_TRANSACTION) — confirm the
// exact semantics in the implementation.
int ggml_mpi_trans_id(struct ggml_mpi_context * ctx_mpi);
// Returns a transaction id received from another node — presumably via
// a GGML_MPI_TRANS_ID-tagged message; verify against ggml-mpi.c.
int ggml_mpi_recv_trans_id(struct ggml_mpi_context * ctx_mpi);
// Increments the context's stored transaction id.
void ggml_mpi_inc_trans_id(struct ggml_mpi_context * ctx_mpi);
/**
 * Initialize the MPI library and the GGML MPI backend.
 * Calling more than once during the lifetime of the program
 * leads to undefined behavior. This function must be called before
 * any MPI operations.
 */
void ggml_mpi_backend_init(void);
// Returns whether this context is currently inside a decode run
// (cf. GGML_MPI_DECODE). NOTE(review): confirm when this flag is
// set/cleared in the implementation.
bool ggml_mpi_is_decoding(struct ggml_mpi_context * ctx_mpi);
// Returns the element count of the last probed/received message,
// interpreted as int32 elements — presumably from MPI_Get_count on the
// stored status; verify against ggml-mpi.c.
int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi);
// Hook invoked after a compute graph is created, given the total layer
// count of the model. NOTE(review): exact graph adjustments performed
// here are defined in the implementation — confirm before relying on it.
void ggml_mpi_graph_creation_post(struct ggml_mpi_context * ctx_mpi, struct ggml_cgraph * cgraph, int n_layers);
// Blocks until an outstanding receive for this context completes.
void ggml_mpi_wait_recv(struct ggml_mpi_context * ctx_mpi);
/**
 * Frees the MPI backend, must be called only once at termination
 * of the program. No MPI operations may be completed after calling this function,
 * and attempting to do so will lead to undefined behavior.
 */
void ggml_mpi_backend_free(void);
/**
 * Construct a new MPI context using the MPI_COMM_WORLD
 * communicator. This is useful only to create the
 * initial context, as calling multiple times
 * will only create effective copies of the same data.
 *
 * @return A context for use in the global communicator.
 */
struct ggml_mpi_context * ggml_mpi_init(void);
/**
 * Create a new context by splitting the given context's
 * communicator, creating a "sub-communicator." This is a collective
 * operation and must be performed by all nodes within the same communicator.
 * The color and key have the same meaning as in MPI_Comm_split(), i.e.
 * the color is used to determine the sub-communicator this node will belong to,
 * and the key is the relative rank of this node in the new communicator.
 *
 * An example: if a node passes a color of 1, and a different node passes a color of 2,
 * the nodes will belong to two different sub-communicators. If two nodes pass the same
 * color, then their ranks will be ordered by the order of their keys. If they pass the same
 * key, then the tie will be broken by the nodes' ranks in the old communicator.
 *
 * The communicator used by the given context remains entirely valid, so it is advisable
 * to store both the old and new contexts. This allows an application to
 * select at runtime which communicator to perform MPI operations with. An example
 * would be to segregate the nodes into multiple domains categorized by the functions
 * they perform, and use the original context to broadcast to all nodes in the cluster.
 *
 * @param ctx The context containing the communicator to split.
 * @param color The sub-communicator that this node will belong to.
 * @param key The relative rank of this node in the new communicator.
 * @return A new context with all values referencing the newly-created communicator.
 */
struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int color, int key);
// Collective barrier: blocks until all nodes in the context's
// communicator have entered it.
void ggml_mpi_barrier(struct ggml_mpi_context * ctx);
// Rank of the next node in the pipeline ring for this communicator —
// presumably (rank + 1) wrapping at the communicator size; confirm the
// wrap behavior in ggml-mpi.c.
int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi);
// Rank of the previous node in the pipeline ring for this communicator.
int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi);
// Propagates `count` int32 values through the pipeline in the forward
// direction (toward the next node), labeled with `tag` (one of the
// GGML_MPI_* tag macros above). `vals` is updated in place on receiving
// nodes.
void ggml_mpi_sync_ints_pipelined(
struct ggml_mpi_context * ctx_mpi,
int32_t * vals,
int count,
int tag
);
// Same as ggml_mpi_sync_ints_pipelined(), but propagates the values in
// the reverse direction (toward the previous node).
void ggml_mpi_sync_ints_pipelined_back(
struct ggml_mpi_context * ctx_mpi,
int32_t * vals,
int count,
int tag
);
// Tag values for KV-cache operations (see the GGML_MPI_KV_* macros):
// clear = 1, rm = 2, cp = 3, keep = 4, seq_shift = 5
// Blocks until a message with the given source and tag is available,
// storing the resulting status in the context.
void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
// Returns the tag of the last probed message (from the stored status).
int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi);
// Non-blocking probe for a message with the given source and tag;
// returns a flag indicating whether such a message is pending.
int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag);
/**
 * Frees the given context, including the communicator. No MPI
 * operations besides ggml_mpi_backend_free() should be executed after
 * running this function.
 *
 * @param ctx The context to free.
 */
void ggml_mpi_free(struct ggml_mpi_context * ctx);
/**
 * Get the rank of this node in the given context's communicator.
 *
 * @param ctx The context to use to determine the rank with regards to.
 * @return The rank of this node.
 */
int ggml_mpi_rank(struct ggml_mpi_context * ctx);
/**
 * Get the number of nodes that are a part of
 * the communicator referenced by the given context.
 *
 * @param ctx The context containing the communicator used for this size check.
 * @return The number of nodes that are a part of the given context's communicator.
 */
size_t ggml_mpi_size(struct ggml_mpi_context * ctx);
/**
 * Synchronize needed information among the nodes
 * to prepare for running an evaluation iteration.
 * This is a collective operation and all nodes must
 * call this function. It will block until all
 * nodes have entered it, to prevent any desync
 * between nodes.
 *
 * @param ctx_mpi The context in which to prepare for evaluation.
 * @param n_tokens A pointer to the n_tokens, which will be synchronized after this function.
 * @param tokens A pointer to the tokens array, which will be synchronized after this function.
 * @param pos A pointer to the pos array, which will be synchronized after this function.
 * @param n_seq_ids A pointer to the n_seq_ids array, which will be synchronized after this function.
 * @param seq_id A pointer to the seq_id 2D array, which will be synchronized after this function.
 * @param logits A pointer to the logits array, which is unused currently since only node 0 needs them.
 * @param batch_id A pointer to the batch id, which will be synchronized after this function.
 * @param receive_only When true, this node only receives the batch data
 *                     without forwarding it — NOTE(review): confirm against
 *                     the implementation.
 * @return Whether initialization succeeded — presumably false signals that
 *         the run was cancelled or shut down; verify in ggml-mpi.c.
 */
bool ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
int32_t * n_tokens,
int32_t ** tokens,
int32_t ** pos,
int32_t ** n_seq_ids,
int32_t *** seq_id,
int8_t ** logits,
int32_t * batch_id,
bool receive_only);
// Broadcast-style synchronization of a single int32 from `root` to the
// other nodes in the context's communicator; `val` is updated in place.
void ggml_mpi_synch_int(
struct ggml_mpi_context * ctx_mpi,
int32_t * val,
int root
);
// Broadcast-style synchronization of a single float from `root` to the
// other nodes in the context's communicator; `val` is updated in place.
void ggml_mpi_synch_float(
struct ggml_mpi_context * ctx_mpi,
float * val,
int root
);
// Blocking receive of `arr_size` floats into `val` from rank `src`,
// matching messages labeled with `tag`.
void ggml_mpi_recv_float_array(
struct ggml_mpi_context * ctx_mpi,
float * val,
int arr_size,
int src,
int tag
);
// Non-blocking (asynchronous) send of `arr_size` floats from `val` to
// rank `dest`, labeled with `tag`. NOTE(review): the buffer must remain
// valid until the send completes — confirm how completion is awaited
// (cf. ggml_mpi_wait_recv) in the implementation.
void ggml_mpi_send_float_array_async(
struct ggml_mpi_context * ctx_mpi,
float * val,
int arr_size,
int dest,
int tag
);
/**
 * Split a range across all nodes within the given
 * context, weighting the allocations by the given weights.
 * The dimensions of the returned 2d array are (number of nodes in the context, 2).
 * The first element in the inner array is the starting point of the range allocated
 * to the node indicated by the index into the outer array,
 * and the second element is the end point of the allocated range, inclusive.
 *
 * @param ctx_mpi The context used to determine the number of nodes
 *                to split the range across.
 * @param start The starting point of the range.
 * @param end The end point of the range, inclusive.
 * @param node_weights How to weight the allocations across the nodes,
 *                     must sum to 1.0.
 * @return A 2d array, the first dimension is the number of nodes in the context
 *         and the second dimension is 2. NOTE(review): the array appears to be
 *         heap-allocated — confirm whether the caller owns and must free it.
 */
uint16_t** ggml_mpi_split_range(
struct ggml_mpi_context * ctx_mpi,
uint16_t start,
uint16_t end,
float node_weights[]
);
/**
 * Scatter the layer ranges across all nodes
 * in the given context. This is a collective operation
 * and must be called by all nodes that are within the same
 * communicator. The given layer ranges must be in the same
 * format as created by the ggml_mpi_split_range().
 *
 * @param ctx_mpi The context to scatter the layers across.
 * @param layer_ranges The pre-split ranges to scatter to the nodes.
 */
void ggml_mpi_scatter_layers(
struct ggml_mpi_context * ctx_mpi,
uint16_t ** layer_ranges
);
/**
 * Modify compute graph to only process allocated
 * layers.
 *
 * @param ctx_mpi The context containing the allocated layer range.
 * @param gf The compute graph to modify.
 * @return Whether the graph was prepared successfully — NOTE(review): the
 *         exact failure condition is defined in the implementation; confirm.
 */
bool ggml_mpi_graph_compute_pre(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf);
/**
 * Sends the output tensor to the next node for processing
 * of later layers.
 *
 * @param ctx_mpi The context to use for MPI operations.
 * @param gf The graph used in the computations.
 */
void ggml_mpi_graph_compute_post(
struct ggml_mpi_context * ctx_mpi,
struct ggml_cgraph * gf);
#ifdef __cplusplus
}
#endif