// Align buffers to 32 bytes to support vectorized code constsize_t kBufferAlignment = 32;
template <typename T, int ALIGNMENT = kBufferAlignment> class aligned_allocator { static_assert( !(ALIGNMENT & (ALIGNMENT - 1)), "alignment must be a power of 2");
public: using value_type = T; using pointer = value_type*; using const_pointer = const value_type*; using reference = value_type&; using const_reference = const value_type&; using size_type = std::size_t; using difference_type = std::ptrdiff_t;
template <typename U> structrebind { using other = aligned_allocator<U, ALIGNMENT>; };
intmain(int/*argc*/, char** /*argv*/){ // We'll use the TCP transport in this example auto dev = gloo::transport::tcp::CreateDevice("localhost");
// Create Gloo context and delegate management of MPI_Init/MPI_Finalize auto context = gloo::mpi::Context::createManaged(); context->connectFullMesh(dev);
// Create and run simple allreduce int rank = context->rank; gloo::AllreduceRing<int> allreduce(context, {&rank}, 1); allreduce.run(); std::cout << "Result: " << rank << std::endl;
// ibv_req_notify_cq(3) takes an argument called 'solicited_only' which
// makes it only trigger a notification for work requests that are
// flagged as solicited. Every completion should trigger a notification,
// so always pass 0.
static constexpr auto kNotifyOnAnyCompletion = 0;
// Send from the specified buffer to remote side of pair. virtualvoidsend( transport::UnboundBuffer* tbuf, uint64_t tag, size_t offset, size_t nbytes)override;
// Receive into the specified buffer from the remote side of pair. virtualvoidrecv( transport::UnboundBuffer* tbuf, uint64_t tag, size_t offset, size_t nbytes)override;
// Completions on behalf of buffers need to be forwarded to those buffers. std::map<int, Buffer*> sendCompletionHandlers_; std::map<int, Buffer*> recvCompletionHandlers_;
voidsendMemoryRegion(struct ibv_mr* mr, int slot); conststructibv_mr* getMemoryRegion(int slot);
// Populate local address. // The Packet Sequence Number field (PSN) is random which makes that // the remote end of this pair needs to have the contents of the // full address struct in order to connect, and vice versa. { structibv_port_attr attr; memset(&attr, 0, sizeof(struct ibv_port_attr)); rv = ibv_query_port(dev_->context_, dev_->attr_.port, &attr); GLOO_ENFORCE_EQ(rv, 0); rv = ibv_query_gid( dev_->context_, dev_->attr_.port, dev_->attr_.index, &self_.addr_.ibv_gid); GLOO_ENFORCE_EQ(rv, 0); self_.addr_.lid = attr.lid; self_.addr_.qpn = qp_->qp_num; self_.addr_.psn = rand() & 0xffffff; }
// 在连接之前发布接收请求。 // 每当这pair的远程端注册接收缓冲区时,就会触发它们的内存注册被发送到这一端。 // 由于这些发送是单方面的,我们总是需要一整套接收工作请求。 // 内存区域接收可以与常规缓冲区写入交错,因此我们主动在每个接收工作请求中包含一个内存区域。 for (int i = 0; i < kMaxBuffers; ++i) { mappedRecvRegions_[i] = make_unique<MemoryRegion>(dev_->pd_); postReceive(); } }
Pair::~Pair() { int rv;
// Acknowledge number of completion events handled by this // pair's completion queue (also see ibv_get_cq_event(3)). ibv_ack_cq_events(cq_, completionEventsHandled_);
// Move to Ready To Send (RTS) state rv = ibv_modify_qp( qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); GLOO_ENFORCE_EQ(rv, 0); }
// Switches the pair into synchronous mode. // // Note: busy polling is NOT optional. Currently, since all pairs // share a single completion channel, busy polling is mandatory // through ibv_poll_cq(3). If a use case comes up for supporting // synchronous mode where the calling thread should be suspended, this // can be revisited and we can add a completion channel per pair. // voidPair::setSync(bool sync, bool busyPoll){ checkErrorState(); if (!sync) { GLOO_THROW_INVALID_OPERATION_EXCEPTION("Can only switch to sync mode"); } if (!busyPoll) { GLOO_THROW_INVALID_OPERATION_EXCEPTION( "The ibverbs transport only supports busy polling in sync mode"); }
// The notification mechanism for this pair's completion queue is // still armed. This means the device thread will still call // handleCompletions() one more time, but this is ignored. // // No need to lock a mutex; these are atomics. // sync_ = true; busyPoll_ = true; }
// Send from the specified buffer to remote side of pair. voidPair::send( transport::UnboundBuffer* tbuf, uint64_t/* unused */, size_t/* unused */, size_t/* unused */){ GLOO_THROW_INVALID_OPERATION_EXCEPTION( "Unbound buffers not supported yet for ibverbs transport"); }
// Receive into the specified buffer from the remote side of pair. voidPair::recv( transport::UnboundBuffer* tbuf, uint64_t/* unused */, size_t/* unused */, size_t/* unused */){ GLOO_THROW_INVALID_OPERATION_EXCEPTION( "Unbound buffers not supported yet for ibverbs transport"); }
// handleCompletionEvent is called by the device thread when it // received an event for this pair's completion queue on its // completion channel. voidPair::handleCompletionEvent(){ int rv;
completionEventsHandled_++;
// If in sync mode, the pair was just switched and this is // the last notification from the device thread because // the notification mechanism is not re-armed below. if (sync_) { return; }
try { checkErrorState();
// Arm notification mechanism for completion queue. rv = ibv_req_notify_cq(cq_, kNotifyOnAnyCompletion); GLOO_ENFORCE_EQ(rv, 0);
// Now poll for work completions to drain the completion queue. std::unique_lock<std::mutex> lock(m_); pollCompletions(); } catch (const ::gloo::IoException&) { // Catch IO exceptions on the event handling thread. The exception has // already been saved and user threads signaled. } }
// Invoke handler for every work completion. for (;;) { auto nwc = ibv_poll_cq(cq_, wc.size(), wc.data()); GLOO_ENFORCE_GE(nwc, 0);
// Handle work completions for (int i = 0; i < nwc; i++) { checkErrorState(); handleCompletion(&wc[i]); }
// Break unless wc was filled if (nwc == 0 || nwc < wc.size()) { break; } } }
voidPair::handleCompletion(struct ibv_wc* wc){ if (wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) { // Incoming RDMA write completed. // Slot is encoded in immediate data on receive work completion. // It is set in the Pair::send function. auto slot = wc->imm_data; GLOO_ENFORCE_EQ( wc->status, IBV_WC_SUCCESS, "Recv for slot ", slot, ": ", ibv_wc_status_str(wc->status));
// Backfill receive work requests. postReceive(); } elseif (wc->opcode == IBV_WC_RDMA_WRITE) { // Outbound RDMA write completed. // Slot is encoded in wr_id fields on send work request. Unlike // the receive work completions, the immediate data field on send // work requests are not pass to the respective work completion. auto slot = wc->wr_id; GLOO_ENFORCE_EQ( wc->status, IBV_WC_SUCCESS, "Send for slot ", slot, ": ", ibv_wc_status_str(wc->status));
// Move ibv_mr from memory region 'inbox' to final slot. constauto& mr = mappedRecvRegions_[recvPosted_ % kMaxBuffers]; peerMemoryRegions_[slot] = mr->mr();
// Notify any buffer waiting for the details of its remote peer. cv_.notify_all();
// Backfill receive work requests. postReceive(); } elseif (wc->opcode == IBV_WC_SEND) { // Memory region send completed. auto slot = wc->wr_id; GLOO_ENFORCE_EQ( wc->status, IBV_WC_SUCCESS, "Memory region send for slot ", slot, ": ", ibv_wc_status_str(wc->status));
structibv_send_wr* bad_wr; auto rv = ibv_post_send(qp_, &wr, &bad_wr); if (rv != 0) { signalIoFailure(GLOO_ERROR_MSG("ibv_post_send: ", rv)); } }
voidPair::signalIoFailure(const std::string& msg){ std::lock_guard<std::mutex> lock(m_); auto ex = ::gloo::IoException(msg); if (ex_ == nullptr) { // If we haven't seen an error yet, store the exception to throw on future calling threads. ex_ = std::make_exception_ptr(ex); // Loop through the completion handlers and signal that an error has // occurred. for (auto& it : recvCompletionHandlers_) { GLOO_ENFORCE(it.second != nullptr); it.second->signalError(ex_); } for (auto& it : sendCompletionHandlers_) { GLOO_ENFORCE(it.second != nullptr); it.second->signalError(ex_); } } // Finally, throw the exception on this thread. throw ex; };
voidPair::checkErrorState(){ // If we previously encountered an error, rethrow here. if (ex_ != nullptr) { std::rethrow_exception(ex_); } }
// Provide hint if the error is EFAULT and nv_peer_mem is not loaded if (mr_ == nullptr && errno == EFAULT) { if (!pair->dev_->hasNvPeerMem_) { GLOO_ENFORCE( mr_ != nullptr, "ibv_reg_mr: ", strerror(errno), " (kernel module 'nv_peer_mem' not loaded;" " did you specify a pointer to GPU memory?)"); } }
// Provide hint if the error is ENOMEM if (mr_ == nullptr && errno == ENOMEM) { GLOO_ENFORCE( mr_ != nullptr, "ibv_reg_mr: ", strerror(errno), " (did you run into the locked memory limit?)"); }
voidBuffer::waitRecv(){ // 如果该pair处于同步模式,则当前线程负责轮询工作完成情况。 // 由于单个pair可能为多个缓冲区提供服务,因此完成可能旨在用于另一个缓冲区。 auto timeout = pair_->getTimeout(); if (pair_->sync_) { auto start = std::chrono::steady_clock::now(); // We can assume a single pair is never used by more than one // thread, so there is no need to acquire the mutex here. while (recvCompletions_ == 0) { pair_->pollCompletions(); if (timeout != kNoTimeout && (std::chrono::steady_clock::now() - start) >= timeout) { pair_->signalIoFailure( GLOO_ERROR_MSG("Read timeout ", pair_->peer().str())); GLOO_ENFORCE(false, "Unexpected code path"); } } recvCompletions_--; } else { // The device thread will signal completion. If the completion // hasn't arrived yet, wait until it does. auto pred = [&]{ checkErrorState(); return recvCompletions_ > 0; }; std::unique_lock<std::mutex> lock(m_); if (timeout == kNoTimeout) { // No timeout set. Wait for read to complete. recvCv_.wait(lock, pred); } else { auto done = recvCv_.wait_for(lock, timeout, pred); if (!done) { // Release the mutex before calling into the pair to avoid deadlock. // Calling signalIoFailure() will throw, so no need to // reacquire. lock.unlock(); pair_->signalIoFailure( GLOO_ERROR_MSG("Read timeout ", pair_->peer().str())); GLOO_ENFORCE(false, "Unexpected code path"); } } recvCompletions_--; } }
// Wait for the previous send operation to finish. voidBuffer::waitSend(){ // 如果该pair处于同步模式,则当前线程负责轮询工作完成情况。 auto timeout = pair_->getTimeout(); if (pair_->sync_) { // We can assume a single pair is never used by more than one // thread, so there is no need to acquire the mutex here. if (sendCompletions_ == 0) { GLOO_ENFORCE_GT(sendPending_, 0, "No send to wait for"); auto start = std::chrono::steady_clock::now(); // We can assume a single pair is never used by more than one // thread, so there is no need to acquire the mutex here. while (sendCompletions_ == 0) { pair_->pollCompletions(); if (timeout != kNoTimeout && (std::chrono::steady_clock::now() - start) >= timeout) { pair_->signalIoFailure( GLOO_ERROR_MSG("Send timeout ", pair_->peer().str())); GLOO_ENFORCE(false, "Unexpected code path"); } } } sendCompletions_--; } else { // The device thread will signal completion. If the completion // hasn't arrived yet, wait until it does. std::unique_lock<std::mutex> lock(m_); checkErrorState(); if (sendCompletions_ == 0) { GLOO_ENFORCE_GT(sendPending_, 0, "No send to wait for"); auto pred = [&]{ checkErrorState(); return sendCompletions_ > 0; }; if (timeout == kNoTimeout) { // No timeout set. Wait for read to complete. sendCv_.wait(lock, pred); } else { auto done = sendCv_.wait_for(lock, timeout, pred); if (!done) { // Release the mutex before calling into the pair to avoid deadlock. // Calling signalIoFailure() will throw, so no need to // reacquire. lock.unlock(); pair_->signalIoFailure( GLOO_ERROR_MSG("Send timeout ", pair_->peer().str())); GLOO_ENFORCE(false, "Unexpected code path"); } } } sendCompletions_--; } }
voidBuffer::send(size_t offset, size_t length, size_t roffset){ int rv;
// Can't assert on roffset, since we don't know the size of // the remote buffer. Refactor of initialization code needed // to support this. GLOO_ENFORCE_LE(offset + length, size_);
// As we don't need to handle legacy clients, // let's remove support for legacy renegotiation: _glootls::SSL_CTX_clear_options(ssl_ctx, SSL_OP_LEGACY_SERVER_CONNECT);
_glootls::SSL_CTX_set_verify_depth(ssl_ctx, 1);
// To enforcing a higher security level, set it to 3. // // 2级 // 安全级别设置为 112 位安全。 因此,禁止使用短于 2048 位的 RSA、DSA 和 DH 密钥以及短于 224 位的 ECC 密钥。 // 除了 1 级排除之外,还禁止使用任何使用 RC4 的密码套件。 SSL 版本 3 也是不允许的。 压缩被禁用。 // // Level 3 // 安全级别设置为 128 位安全。 // 因此,禁止使用小于 3072 位的 RSA、DSA 和 DHkey 以及小于 256 位的 ECC 密钥。 // 除了 2 级排除之外,禁止使用不提供前向保密的密码套件。 不允许使用低于 1.1 的 TLS 版本。 会话票证被禁用。 // // TODO: should be 3, but it doesn't work yet :( _glootls::SSL_CTX_set_security_level(ssl_ctx, 2);
// See if there is a remote pending send that can fulfill this recv. auto it = findPendingOperations(slot); if (it != pendingOperations_.end()) { auto& pendingOperation = *it;
// Out of all remote pending sends, find the first one // that exists in the set of eligible ranks. for (constauto rank : pendingOperation.getSendList()) { for (constauto srcRank : srcRanks) { if (rank == srcRank) { // 我们找到了一个可以满足这个recv的等级。 // 此函数的调用者将尝试进行recv,如果该远程挂起发送操作仍然存在,它将删除它。 // return rank; } } } }
// No candidates; register buffer for recv pendingRecv_[slot].emplace_back( buf->getWeakNonOwningPtr(), offset, nbytes, std::unordered_set<int>(srcRanks.begin(), srcRanks.end())); return-1; }
// Allowed to be called only by ContextMutator::findRecvFromAny, // where the context lock is already held. boolContext::findRecvFromAny( uint64_t slot, int rank, WeakNonOwningPtr<UnboundBuffer>* buf, size_t* offset, size_t* nbytes){ // See if there is a pending recv for this slot. auto pit = pendingRecv_.find(slot); if (pit != pendingRecv_.end()) { auto& recvs = pit->second;
// Iterate over available buffers to find a match. for (auto rit = recvs.begin(); rit != recvs.end(); rit++) { constauto& ranks = std::get<3>(*rit);
// Wait for loop to tick before returning, to make sure the handler // for this fd is not called once this function returns. if (std::this_thread::get_id() != loop_->get_id()) { std::unique_lock<std::mutex> lock(m_); cv_.wait(lock); TSAN_ANNOTATE_HAPPENS_AFTER(h); } }
voidLoop::run(){ std::array<struct epoll_event, capacity_> events; int nfds;
while (!done_) { // Wakeup everyone waiting for a loop tick to finish. cv_.notify_all();
// Wait for something to happen nfds = epoll_wait(fd_, events.data(), events.size(), 10); if (nfds == 0) { continue; } if (nfds == -1 && errno == EINTR) { continue; }
GLOO_ENFORCE_NE(nfds, -1);
for (int i = 0; i < nfds; i++) { Handler* h = reinterpret_cast<Handler*>(events[i].data.ptr); h->handleEvents(events[i].events); TSAN_ANNOTATE_HAPPENS_BEFORE(h); } } }
// Use weak pointer so that the initializer is destructed when the // last context referring to it is destructed, not when statics // are destructed on program termination. static std::weak_ptr<MPIScope> wptr; std::shared_ptr<MPIScope> sptr;
// Create MPIScope only once std::call_once(once, [&]() { sptr = std::make_shared<MPIScope>(); wptr = sptr; });
// Create shared_ptr<MPIScope> from weak_ptr sptr = wptr.lock(); GLOO_ENFORCE(sptr, "Cannot create MPI context after MPI_Finalize()"); return sptr; }
返回MPI上下文(通信域)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Creates a context on MPI_COMM_WORLD that shares ownership of the
// MPI scope returned by getMPIScope().
std::shared_ptr<Context> Context::createManaged() {
  // Obtain the scope before constructing the context; the original
  // ordering is kept (presumably MPI must be set up before the
  // communicator is used -- confirm against getMPIScope/MPIScope).
  auto scope = getMPIScope();

  auto managed = std::make_shared<Context>(MPI_COMM_WORLD);
  managed->mpiScope_ = std::move(scope);
  return managed;
}
voidContext::connectFullMesh(std::shared_ptr<transport::Device>& dev){ std::vector<std::vector<char>> addresses(size); unsignedlong maxLength = 0; int rv;
// Create pair to connect to every other node in the collective auto transportContext = dev->createContext(rank, size); transportContext->setTimeout(getTimeout()); for (int i = 0; i < size; i++) { if (i == rank) { continue; }
auto& pair = transportContext->createPair(i);
// Store address for pair for this rank auto address = pair->address().bytes(); maxLength = std::max(maxLength, address.size()); addresses[i] = std::move(address); }
// Agree on maximum length so we can prepare buffers rv = MPI_Allreduce( MPI_IN_PLACE, &maxLength, 1, MPI_UNSIGNED_LONG, MPI_MAX, comm_); if (rv != MPI_SUCCESS) { GLOO_THROW_IO_EXCEPTION("MPI_Allreduce: ", rv); }
// Prepare input and output std::vector<char> in(size * maxLength); std::vector<char> out(size * size * maxLength); for (int i = 0; i < size; i++) { if (i == rank) { continue; }
// Type of reduction function.
//
// If the reduction type is one of the built-in types, an algorithm
// implementation can use an accelerated version if one is available.
// For example, if a ReductionFunction with ReductionType equal to SUM
// is passed to a CUDA-aware Allreduce, it knows it can use an NCCL
// implementation instead of the specified function.
//
enum ReductionType {
  SUM = 1,
  PRODUCT = 2,
  MAX = 3,
  MIN = 4,

  // Use larger number so we have plenty of room to add built-ins
  CUSTOM = 1000,
};
template <typename T> classReductionFunction { public: using Function = void(T*, const T*, size_t n);
// Local operation.
// If an algorithm uses multiple local pointers, local operations
// can be used for local reduction, broadcast, gathering, etc.
//
// Implementations provide runAsync() to start the operation and wait()
// to block until it has finished.
template <typename T>
class LocalOp {
 public:
  // The destructor is declared noexcept(false), so it is allowed to throw.
  virtual ~LocalOp() noexcept(false) {}

  // Starts the operation.
  virtual void runAsync() = 0;

  // Waits for an operation started with runAsync().
  virtual void wait() = 0;

  // Synchronous run is equal to asynchronous run and wait.
  inline void run() {
    runAsync();
    wait();
  }
};
allgather
AllgatherRing 类似于 MPI_Allgather:所有进程都会从其他所有进程接收缓冲区(inPtrs)的内容。调用者需要传入一个预先分配的接收缓冲区(outPtr),其大小等于 [上下文大小 × 发送缓冲区(inPtrs)的总大小];rank = k 的进程的发送缓冲区会被连续写入到从 outPtr[k * 输入缓冲区数量 * count] 开始的位置。
// If the input buffer is specified, this is NOT an in place operation, // and the output buffer needs to be primed with the input. if (in != nullptr) { memcpy( static_cast<uint8_t*>(out->ptr) + context->rank * in->size, static_cast<uint8_t*>(in->ptr), in->size); }
// Short circuit if there is only a single process. if (context->size == 1) { return; }
// The chunk size may not be divisible by 2; use dynamic lookup. std::array<size_t, 2> chunkSize; chunkSize[0] = inBytes / 2; chunkSize[1] = inBytes - chunkSize[0]; std::array<size_t, 2> chunkOffset; chunkOffset[0] = 0; chunkOffset[1] = chunkSize[0];
// Wait for pending operations to complete to synchronize with the // previous iteration. Because we kick off two operations before // getting here we always wait for the next-to-last operation. out->waitSend(opts.timeout); out->waitRecv(opts.timeout); out->send(sendRank, slot, sendOffset, size); out->recv(recvRank, slot, recvOffset, size); }
// Wait for completes for (auto i = 0; i < 2; i++) { out->waitSend(opts.timeout); out->waitRecv(opts.timeout); } }
// 计算每个进程对应的长度和偏移 std::vector<size_t> byteCounts; std::vector<size_t> byteOffsets; byteCounts.reserve(context->size); byteOffsets.reserve(context->size); size_t offset = 0; for (constauto& elements : opts.elements) { constauto bytes = elements * opts.elementSize; byteCounts.push_back(bytes); byteOffsets.push_back(offset); offset += bytes; }
// 如果指定了输入缓冲区,则需要准备输出缓冲区。 if (in != nullptr) { GLOO_ENFORCE_EQ(byteCounts[context->rank], in->size); if (byteCounts[context->rank] > 0) { memcpy( static_cast<uint8_t*>(out->ptr) + byteOffsets[context->rank], static_cast<uint8_t*>(in->ptr), in->size); } }
// Short circuit if there is only a single process. if (context->size == 1) { return; }
constauto baseIndex = context->size + context->rank; for (auto i = 0; i < context->size - 1; i++) { constsize_t sendIndex = (baseIndex - i) % context->size; constsize_t recvIndex = (baseIndex - i - 1) % context->size;
if (i == 0) { out->send(sendRank, slot, byteOffsets[sendIndex], byteCounts[sendIndex]); out->recv(recvRank, slot, byteOffsets[recvIndex], byteCounts[recvIndex]); continue; }
// Wait for previous operations to complete before kicking off new ones. out->waitSend(opts.timeout); out->waitRecv(opts.timeout); out->send(sendRank, slot, byteOffsets[sendIndex], byteCounts[sendIndex]); out->recv(recvRank, slot, byteOffsets[recvIndex], byteCounts[recvIndex]); }
// Wait for final operations to complete. out->waitSend(opts.timeout); out->waitRecv(opts.timeout); }
using BufferVector = std::vector<std::unique_ptr<transport::UnboundBuffer>>; using ReductionFunction = AllreduceOptions::Func; using ReduceRangeFunction = std::function<void(size_t, size_t)>; using BroadcastRangeFunction = std::function<void(size_t, size_t)>;
// Forward declaration of ring algorithm implementation. voidring( const detail::AllreduceOptionsImpl& opts, ReduceRangeFunction reduceInputs, BroadcastRangeFunction broadcastOutputs);
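// ReductionFunction type describes the function to use for element-wise
// reduction.
//
// Its arguments are:
//   1. non-const output pointer
//   2. const input pointer 1 (may be equal to argument 1)
//   3. const input pointer 2 (may be equal to argument 1)
//   4. number of elements to reduce.
//
// Note that this function is not strictly typed and takes void pointers.
// This is done to avoid the need for a templated options class and
// templated algorithm implementations. We found that templating adds
// very little value for the increase in compile time and code size.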
// If the segment is entirely in range, the following statement is // equal to segmentBytes. If it isn't, it will be less, or even // negative. This is why the ssize_t typecasts are needed. result.sendLength = std::min( (ssize_t)segmentBytes, (ssize_t)totalBytes - (ssize_t)result.sendOffset); result.recvLength = std::min( (ssize_t)segmentBytes, (ssize_t)totalBytes - (ssize_t)result.recvOffset);
return result; };
// Ring reduce/scatter. // // 迭代次数计算如下: // - 使用 `numSegments` 作为段的总数, // - 减去 `numSegmentsPerRank`,因为最终段包含部分结果,在此阶段不得转发。 // - 添加 2,因为我们通过管道发送和接收操作(我们在迭代 0 和 1 上发出发送/接收操作并等待它们在迭代 2 和 3 上完成)。 // for (auto i = 0; i < (numSegments - numSegmentsPerRank + 2); i++) { if (i >= 2) { // 计算两次迭代前的发送和接收偏移量和长度。 // 需要这样我们知道何时等待操作以及何时忽略(当偏移量超出范围时),并知道在哪里减少临时缓冲区的内容。 auto prev = computeReduceScatterOffsets(i - 2); if (prev.recvLength > 0) { // Prepare out[0]->ptr to hold the local reduction reduceInputs(prev.recvOffset, prev.recvLength); // Wait for segment from neighbor. tmp->waitRecv(opts.timeout); // 对收到的段进行reduce opts.reduce( static_cast<uint8_t*>(out[0]->ptr) + prev.recvOffset, static_cast<constuint8_t*>(out[0]->ptr) + prev.recvOffset, static_cast<constuint8_t*>(tmp->ptr) + segmentOffset[i & 0x1], prev.recvLength / opts.elementSize); } if (prev.sendLength > 0) { out[0]->waitSend(opts.timeout); } }
// 在最后两次迭代之外的所有迭代中发出新的发送和接收操作。 // 那时我们已经发送了我们需要的所有数据,只需要等待最终的段被reduce到输出中。 if (i < (numSegments - numSegmentsPerRank)) { // Compute send and receive offsets and lengths for this iteration. auto cur = computeReduceScatterOffsets(i); if (cur.recvLength > 0) { tmp->recv(recvRank, slot, segmentOffset[i & 0x1], cur.recvLength); } if (cur.sendLength > 0) { // Prepare out[0]->ptr to hold the local reduction for this segment if (i < numSegmentsPerRank) { reduceInputs(cur.sendOffset, cur.sendLength); } out[0]->send(sendRank, slot, cur.sendOffset, cur.sendLength); } } }
// Function computes the offsets and lengths of the segments to be // sent and received for a given iteration during allgather. auto computeAllgatherOffsets = [&](size_t i) { struct { size_t sendOffset; size_t recvOffset; ssize_t sendLength; ssize_t recvLength; } result;
// If the segment is entirely in range, the following statement is // equal to segmentBytes. If it isn't, it will be less, or even // negative. This is why the ssize_t typecasts are needed. result.sendLength = std::min( (ssize_t)segmentBytes, (ssize_t)totalBytes - (ssize_t)result.sendOffset); result.recvLength = std::min( (ssize_t)segmentBytes, (ssize_t)totalBytes - (ssize_t)result.recvOffset);
return result; };
// Ring allgather. // // 注意:totalBytes <= (numSegments * segmentBytes), // 这与在进程间贡献相同的通用 allgather 算法不兼容。 // for (auto i = 0; i < (numSegments - numSegmentsPerRank + 2); i++) { if (i >= 2) { auto prev = computeAllgatherOffsets(i - 2); if (prev.recvLength > 0) { out[0]->waitRecv(opts.timeout); // Broadcast received segments to output buffers. broadcastOutputs(prev.recvOffset, prev.recvLength); } if (prev.sendLength > 0) { out[0]->waitSend(opts.timeout); } }
// 在最后两次迭代之外的所有迭代中发出新的发送和接收操作。 // 那时我们已经发送了我们需要的所有数据,只需要等待最终的段被发送到输出。 if (i < (numSegments - numSegmentsPerRank)) { auto cur = computeAllgatherOffsets(i); if (cur.recvLength > 0) { out[0]->recv(recvRank, slot, cur.recvOffset, cur.recvLength); } if (cur.sendLength > 0) { out[0]->send(sendRank, slot, cur.sendOffset, cur.sendLength); // Broadcast first segments to outputs buffers. if (i < numSegmentsPerRank) { broadcastOutputs(cur.sendOffset, cur.sendLength); } } } } }
// Describes one group of peers and the buffer ranges it reduces.
struct group {
  // Distance between peers in this group.
  size_t peerDistance;

  // Segment that this group is responsible for reducing.
  size_t bufferOffset;
  size_t bufferLength;

  // The process ranks that are a member of this group.
  std::vector<size_t> ranks;

  // Upper bound of the length of the chunk that each process has the
  // reduced values for by the end of the reduction for this group.
  size_t chunkLength;

  // Chunk within the segment that this process is responsible for reducing.
  size_t myChunkOffset;
  size_t myChunkLength;
};
// Wait for send and receive operations to complete. for (size_t i = 0; i < group.ranks.size(); i++) { constauto peer = group.ranks[i]; if (peer == context->rank) { continue; } tmp->waitRecv(); out->waitSend(); }
// Allgather. for (auto it = groups.rbegin(); it != groups.rend(); it++) { constauto& group = *it;
// Issue receive operations for reduced chunks from peers. for (size_t i = 0; i < group.ranks.size(); i++) { constauto src = group.ranks[i]; if (src == context->rank) { continue; } constsize_t currentChunkOffset = group.bufferOffset + i * group.chunkLength; constsize_t currentChunkLength = std::min( size_t(group.chunkLength), size_t(std::max( int64_t(0), int64_t(group.bufferLength) - int64_t(i * group.chunkLength)))); out->recv( src, slot, currentChunkOffset * elementSize, currentChunkLength * elementSize); }
// Issue send operations for reduced chunk to peers. for (size_t i = 0; i < group.ranks.size(); i++) { constauto dst = group.ranks[i]; if (dst == context->rank) { continue; } out->send( dst, slot, group.myChunkOffset * elementSize, group.myChunkLength * elementSize); }
// Wait for operations to complete. for (size_t i = 0; i < group.ranks.size(); i++) { constauto peer = group.ranks[i]; if (peer == context->rank) { continue; } out->waitRecv(); out->waitSend(); }
// Broadcast result to multiple output buffers, if applicable. for (size_t i = 0; i < group.ranks.size(); i++) { constauto peer = group.ranks[i]; if (peer == context->rank) { continue; } constsize_t currentChunkOffset = group.bufferOffset + i * group.chunkLength; constsize_t currentChunkLength = std::min( size_t(group.chunkLength), size_t(std::max( int64_t(0), int64_t(group.bufferLength) - int64_t(i * group.chunkLength)))); broadcastOutputs( currentChunkOffset * elementSize, currentChunkLength * elementSize); } } }
// Below implements a dissemination barrier, described in "Two algorithms // for barrier synchronization (1988)" by Hensgen, Finkel and Manber. // PDF: https://www.inf.ed.ac.uk/teaching/courses/ppls/BarrierPaper.pdf // DOI: 10.1007/BF01379320
// Instead of iterating over i up to log2(context->size), we immediately // compute 2^i and compare with context->size. for (size_t d = 1; d < context->size; d <<= 1) { buffer->recv((context->size + context->rank - d) % context->size, slot); buffer->send((context->size + context->rank + d) % context->size, slot); buffer->waitRecv(opts.timeout); buffer->waitSend(opts.timeout); } }