-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcontext.h
370 lines (287 loc) · 12.8 KB
/
context.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
// Copyright 2019 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <spdlog/common.h>
#include <atomic>
#include <cstdint>
#include <limits>
#include <map>
#include <string>
#include <vector>
#include "yacl/base/byte_container_view.h"
#include "yacl/link/retry_options.h"
#include "yacl/link/ssl_options.h"
#include "yacl/link/transport/channel.h"
#include "yacl/utils/hash_combine.h"
#include "yacl/link/link.pb.h"
namespace yacl::link {
// Sentinel rank value: all bits set, i.e. the largest representable size_t,
// so it cannot coincide with any valid rank index. As the name suggests it
// presumably stands for "all ranks" — confirm against callers.
constexpr size_t kAllRank = ~size_t{0};
struct ContextDesc {
static constexpr char kDefaultId[] = "root";
static constexpr uint32_t kDefaultConnectRetryTimes = 10;
static constexpr uint32_t kDefaultConnectRetryIntervalMs = 1000; // 1 second.
static constexpr uint64_t kDefaultRecvTimeoutMs = 30 * 1000; // 30s
static constexpr uint32_t kDefaultHttpMaxPayloadSize =
1024 * 1024; // 1M Bytes
static constexpr uint32_t kDefaultHttpTimeoutMs = 20 * 1000; // 20 seconds.
static constexpr uint32_t kDefaultThrottleWindowSize = 10;
static constexpr uint32_t kDefaultChunkParallelSendSize = 8;
static constexpr char kDefaultBrpcChannelProtocol[] = "baidu_std";
static constexpr char kDefaultLinkType[] = "normal";
struct Party {
std::string id;
std::string host;
bool operator==(const Party& p) const {
return (id == p.id) && (host == p.host);
}
Party() = default;
Party(const PartyProto& pb) : id(pb.id()), host(pb.host()) {}
Party(const std::string& id_, const std::string& host_)
: id(id_), host(host_) {}
};
// the UUID of this communication.
std::string id = kDefaultId;
// party description, describes the world.
std::vector<Party> parties;
// connect to mesh retry time.
uint32_t connect_retry_times = kDefaultConnectRetryTimes;
// connect to mesh retry interval.
uint32_t connect_retry_interval_ms =
kDefaultConnectRetryIntervalMs; // 1 second.
// recv timeout in milliseconds.
//
// 'recv time' is the max time that a party will wait for a given event.
// for example:
//
// begin recv end recv
// |--------|-------recv-time----------|------------------| alice's timeline
//
// begin send end send
// |-----busy-work-------------|-------------|------------| bob's timeline
//
// in above case, when alice begins recv for a specific event, bob is still
// busy doing its job, when alice's wait time exceed wait_timeout_ms, it raise
// exception, although bob now is starting to send data.
//
// so for long time work(that one party may wait for the others for very long
// time), this value should be changed accordingly.
uint64_t recv_timeout_ms = kDefaultRecvTimeoutMs; // 30s
// http max payload size, if a single http request size is greater than this
// limit, it will be unpacked into small chunks then reassembled.
//
// This field does affect performance. Please choose wisely.
uint32_t http_max_payload_size = kDefaultHttpMaxPayloadSize; // 1M Bytes
// a single http request timetout.
uint32_t http_timeout_ms = kDefaultHttpTimeoutMs; // 20 seconds.
// throttle window size for channel. if there are more than limited size
// messages are flying, `SendAsync` will block until messages are processed or
// throw exception after wait for `recv_timeout_ms`
uint32_t throttle_window_size = kDefaultThrottleWindowSize;
// chunk parallel send size for channel. if need chunked send when send
// message, the max paralleled send size is chunk_parallel_send_size
uint32_t chunk_parallel_send_size = kDefaultChunkParallelSendSize;
// BRPC client channel protocol.
std::string brpc_channel_protocol = kDefaultBrpcChannelProtocol;
// BRPC client channel connection type.
std::string brpc_channel_connection_type = "";
// ssl options for link channel
bool enable_ssl = false;
// ssl options for link channel
// this option is ignored if enable_ssl == false;
SSLOptions client_ssl_opts;
// ssl options for link service
// this option is ignored if enable_ssl == false;
SSLOptions server_ssl_opts;
// if true, process will exit(-1) when error happened in link async operate
// otherwise, only log error.
bool exit_if_async_error = true;
// "blackbox" or "normal", default: "normal"
std::string link_type = kDefaultLinkType;
RetryOptions retry_opts;
bool disable_msg_seq_id = false;
bool operator==(const ContextDesc& other) const {
return (id == other.id) && (parties == other.parties);
}
ContextDesc() = default;
ContextDesc(const ContextDescProto& pb)
: id(pb.id().size() ? pb.id() : kDefaultId),
connect_retry_times(pb.connect_retry_times()
? pb.connect_retry_times()
: kDefaultConnectRetryTimes),
connect_retry_interval_ms(pb.connect_retry_interval_ms()
? pb.connect_retry_interval_ms()
: kDefaultConnectRetryIntervalMs),
recv_timeout_ms(pb.recv_timeout_ms() ? pb.recv_timeout_ms()
: kDefaultRecvTimeoutMs),
http_max_payload_size(pb.http_max_payload_size()
? pb.http_max_payload_size()
: kDefaultHttpMaxPayloadSize),
http_timeout_ms(pb.http_timeout_ms() ? pb.http_timeout_ms()
: kDefaultHttpTimeoutMs),
throttle_window_size(pb.throttle_window_size()
? pb.throttle_window_size()
: kDefaultThrottleWindowSize),
chunk_parallel_send_size(pb.chunk_parallel_send_size()
? pb.chunk_parallel_send_size()
: kDefaultChunkParallelSendSize),
brpc_channel_protocol(pb.brpc_channel_protocol().size()
? pb.brpc_channel_protocol()
: kDefaultBrpcChannelProtocol),
brpc_channel_connection_type(pb.brpc_channel_connection_type()),
enable_ssl(pb.enable_ssl()),
client_ssl_opts(pb.client_ssl_opts()),
server_ssl_opts(pb.server_ssl_opts()),
link_type(kDefaultLinkType),
retry_opts(pb.retry_opts()) {
for (const auto& party_pb : pb.parties()) {
parties.emplace_back(party_pb);
}
}
};
// Hash functor for ContextDesc (e.g. as the Hash argument of an unordered
// container keyed by ContextDesc).
//
// The standard Hash requirements demand that `a == b` implies
// `hash(a) == hash(b)`. ContextDesc::operator== compares only `id` and
// `parties`, so only those fields may feed the hash. Mixing in tuning knobs
// (timeouts, throttle size, protocol, link type) would give equal descriptors
// different hashes and make unordered-container lookups unreliable.
struct ContextDescHasher {
  size_t operator()(const ContextDesc& desc) const {
    size_t seed = 0;
    utils::hash_combine(seed, desc.id);
    for (const auto& p : desc.parties) {
      utils::hash_combine(seed, p.id, p.host);
    }
    return seed;
  }
};
// Link traffic counters. Each counter is a std::atomic starting at zero, so it
// can be incremented without an external lock.
struct Statistics {
  // Payload bytes sent; message keys are not counted.
  std::atomic<size_t> sent_bytes{0U};
  // Send operations performed; a chunked send counts as one operation.
  std::atomic<size_t> sent_actions{0U};
  // Payload bytes received; message keys are not counted.
  std::atomic<size_t> recv_bytes{0U};
  // Receive operations performed; a chunked receive counts as one operation.
  std::atomic<size_t> recv_actions{0U};
};
// Threading: link context could only be used in one thread, since
// communication rounds are identified by (incremental) counters.
//
// Spawn it if you need to use it in a different thread, the
// channels/event_loop will be shared between parent/child contexts.
class Context {
public:
Context(ContextDesc desc, size_t rank,
std::vector<std::shared_ptr<transport::IChannel>> channels,
std::shared_ptr<transport::IReceiverLoop> msg_loop,
bool is_sub_world = false);
Context(const ContextDescProto& desc_pb, size_t rank,
std::vector<std::shared_ptr<transport::IChannel>> channels,
std::shared_ptr<transport::IReceiverLoop> msg_loop,
bool is_sub_world = false);
std::string Id() const;
size_t WorldSize() const;
size_t Rank() const;
size_t NextRank(size_t stride = 1) const;
size_t PrevRank(size_t stride = 1) const;
// P2P algorithms
void SendAsync(size_t dst_rank, ByteContainerView value,
std::string_view tag);
void SendAsync(size_t dst_rank, Buffer&& value, std::string_view tag);
void SendAsyncThrottled(size_t dst_rank, ByteContainerView value,
std::string_view tag);
void SendAsyncThrottled(size_t dst_rank, Buffer&& value,
std::string_view tag);
void Send(size_t dst_rank, ByteContainerView value, std::string_view tag);
Buffer Recv(size_t src_rank, std::string_view tag);
// Connect to mesh, you can also set the connect log level to any
// spdlog::level
void ConnectToMesh(
spdlog::level::level_enum connect_log_level = spdlog::level::debug);
std::unique_ptr<Context> Spawn(const std::string& id = "");
// Create a new Context from a subset of original parities.
// Party which not in `sub_parties` should not call the SubWorld() method.
// `id_suffix` will append to original context id as new context id
std::unique_ptr<Context> SubWorld(
std::string_view id_suffix,
const std::vector<std::string>& sub_party_ids);
void SetRecvTimeout(uint64_t recv_timeout_ms);
uint64_t GetRecvTimeout() const;
void WaitLinkTaskFinish();
void AbortLink();
void SetThrottleWindowSize(size_t);
void SetChunkParallelSendSize(size_t);
// for internal algorithms.
void SendAsyncInternal(size_t dst_rank, const std::string& key,
ByteContainerView value);
void SendAsyncInternal(size_t dst_rank, const std::string& key,
Buffer&& value);
void SendAsyncThrottledInternal(size_t dst_rank, const std::string& key,
ByteContainerView value);
void SendAsyncThrottledInternal(size_t dst_rank, const std::string& key,
Buffer&& value);
void SendInternal(size_t dst_rank, const std::string& key,
ByteContainerView value);
Buffer RecvInternal(size_t src_rank, const std::string& key);
// next collective algorithm id.
std::string NextId();
std::string PartyIdByRank(size_t rank) { return desc_.parties[rank].id; }
// next P2P comm id.
std::string NextP2PId(size_t src_rank, size_t dst_rank);
// for external message loop
std::shared_ptr<transport::IChannel> GetChannel(size_t src_rank) const;
// print statistics
void PrintStats();
// get statistics
std::shared_ptr<const Statistics> GetStats() const;
protected:
using P2PDirection = std::pair<int, int>;
const ContextDesc desc_; // world description.
const size_t rank_; // my rank.
const std::vector<std::shared_ptr<transport::IChannel>> channels_;
const std::shared_ptr<transport::IReceiverLoop> receiver_loop_;
// stateful properties.
size_t counter_ = 0U; // collective algorithm counter.
std::map<P2PDirection, int> p2p_counter_;
size_t child_counter_ = 0U;
uint64_t recv_timeout_ms_;
// sub-context will shared statistics with parent
std::shared_ptr<Statistics> stats_;
const bool is_sub_world_;
};
// RecvTimeoutGuard temporarily overrides the recv timeout of a Context and
// restores the previous value when the guard is destroyed.
// for example:
//   {
//     RecvTimeoutGuard guard(ctx, timeout);
//     method();
//   }
// in above case, the Context's recv timeout is set to `timeout` before
// `method()` runs and is restored to its original value when `guard` goes out
// of scope (also when `method()` throws, since the destructor still runs).
//
// The guard keeps its own shared_ptr copy of the Context, rather than a
// reference to the caller's shared_ptr. It therefore stays valid even when it
// was constructed from a temporary shared_ptr, and the Context outlives the
// guard.
class RecvTimeoutGuard {
 public:
  // Saves the current recv timeout of `ctx`, then applies `recv_timeout_ms`.
  RecvTimeoutGuard(const std::shared_ptr<Context>& ctx,
                   uint64_t recv_timeout_ms)
      : ctx_(ctx), recv_timeout_ms_(ctx->GetRecvTimeout()) {
    ctx_->SetRecvTimeout(recv_timeout_ms);
  }

  // Restores the recv timeout that was active when the guard was created.
  ~RecvTimeoutGuard() { ctx_->SetRecvTimeout(recv_timeout_ms_); }

  RecvTimeoutGuard(const RecvTimeoutGuard&) = delete;
  RecvTimeoutGuard& operator=(const RecvTimeoutGuard&) = delete;

 private:
  // Owned copy: a `const std::shared_ptr<Context>&` member would dangle if
  // the caller passed a temporary, making the destructor undefined behaviour.
  const std::shared_ptr<Context> ctx_;
  // Timeout to restore on destruction.
  uint64_t recv_timeout_ms_;
};
} // namespace yacl::link