diff --git a/apps/nccl/src/broadcast.hpp b/apps/nccl/src/broadcast.hpp
new file mode 100644
index 00000000..76899f93
--- /dev/null
+++ b/apps/nccl/src/broadcast.hpp
@@ -0,0 +1,171 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#ifndef BROADCAST_HPP_
+#define BROADCAST_HPP_
+
+#include <mscclpp/concurrency_device.hpp>
+#include <mscclpp/core.hpp>
+#include <mscclpp/gpu.hpp>
+#include <mscclpp/sm_channel.hpp>
+#include <mscclpp/sm_channel_device.hpp>
+
+#include "common.hpp"
+
+template <bool IsOutOfPlace>
+__global__ void __launch_bounds__(1024, 1)
+    broadcast6(void* sendbuff, void* scratchbuff, void* recvbuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
+               size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, size_t root,
+               size_t nRanksPerNode, size_t nelemsPerGPU) {
+  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const size_t lid = tid % WARP_SIZE;
+  const size_t wid = tid / WARP_SIZE;
+
+  const size_t nThread = blockDim.x * gridDim.x;
+  const size_t nWarp = nThread / WARP_SIZE;
+  const size_t nPeer = nRanksPerNode - 1;
+  const size_t chanOffset = nPeer * blockIdx.x;
+
+  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> smChans[NRANKS_PER_NODE - 1];
+  if (threadIdx.x < nPeer) {
+    smChans[threadIdx.x] = smChannels[chanOffset + threadIdx.x];
+    smChans[threadIdx.x].relaxedSignal();
+    smChans[threadIdx.x].wait();
+  }
+  __syncthreads();
+
+  const size_t peerRootIdx = (root == rank) ? nPeer : ((root < rank) ? root : (root - 1));
+
+  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
+  const size_t bytes = bytesPerGPU;
+  size_t unitBytesPerThread;
+  if (bytes * nPeer >= nThread * 64) {
+    unitBytesPerThread = 64;
+  } else {
+    unitBytesPerThread = 16;
+  }
+  const size_t unitBytesPerBlock = unitBytesPerThread * blockDim.x;
+  const size_t unitBytes = unitBytesPerBlock * gridDim.x;
+  const size_t nLoop = bytes / unitBytes;
+
+  const size_t maxScratchSizeToUse = (SCRATCH_SIZE - unitBytes);
+  const size_t nLoopToSync = (maxScratchSizeToUse / unitBytes) + 1;
+
+  size_t scratchSub = 0;
+
+  // First loop will always fit the scratch size.
+  if (nLoop > 0) {
+    // First loop unrolling
+    const size_t offset = blockIdx.x * unitBytesPerBlock;
+    if (rank == root) {
+      char* send_ = reinterpret_cast<char*>(sendbuff);
+      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
+        smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
+        __syncthreads();
+        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
+      }
+      if constexpr (IsOutOfPlace) {
+        char* recv_ = reinterpret_cast<char*>(recvbuff);
+        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
+      }
+
+    } else {  // rank != root.
+      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
+      __syncthreads();
+      char* recv_ = reinterpret_cast<char*>(recvbuff);
+      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
+      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x,
+                                           blockDim.x);
+    }
+  }
+
+  for (size_t i = 1; i < nLoop; ++i) {
+    const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
+    if (i % nLoopToSync == 0) {  // Sync to reuse scratch buff
+      scratchSub = -i * unitBytes;
+      deviceSyncer.sync(gridDim.x);
+      if (threadIdx.x < nPeer) {
+        smChans[threadIdx.x].relaxedSignal();
+        smChans[threadIdx.x].wait();
+      }
+    }
+    if (rank == root) {
+      char* send_ = reinterpret_cast<char*>(sendbuff);
+      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
+        smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
+                                         blockDim.x);
+        __syncthreads();
+        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
+      }
+      if constexpr (IsOutOfPlace) {
+        char* recv_ = reinterpret_cast<char*>(recvbuff);
+        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
+      }
+    } else {  // rank != root.
+      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
+      __syncthreads();
+      char* recv_ = reinterpret_cast<char*>(recvbuff);
+      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
+      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock,
+                                           threadIdx.x, blockDim.x);
+    }
+  }
+
+  // Remainder loop will also fit the scratch buff since we subtract unitBytes from SCRATCH_SIZE.
+  if (bytes % unitBytes > 0) {  // remainder.
+    const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
+    const size_t remainBytes = (offset < bytes) ? (bytes - offset) : 0;
+    if (remainBytes > 0) {
+      if (rank == root) {
+        char* send_ = reinterpret_cast<char*>(sendbuff);
+        for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+          char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
+          smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x,
+                                          blockDim.x);
+          __syncthreads();
+          if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
+        }
+        if constexpr (IsOutOfPlace) {
+          char* recv_ = reinterpret_cast<char*>(recvbuff);
+          smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
+        }
+      } else {  // rank != root.
+        if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
+        __syncthreads();
+        char* recv_ = reinterpret_cast<char*>(recvbuff);
+        char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
+        smChans[peerRootIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
+                                            blockDim.x);
+      }
+    }  // remainBytes > 0.
+  }
+
+  deviceSyncer.sync(gridDim.x);
+
+  if (threadIdx.x < nPeer) {
+    smChans[threadIdx.x].relaxedSignal();
+    smChans[threadIdx.x].wait();
+  }
+}
+
+template <bool IsOutOfPlace, typename T>
+cudaError_t broadcast(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
+                      size_t channelOutOffset, int rank, int nRanksPerNode, int root, int worldSize, size_t nelems,
+                      cudaStream_t stream) {
+  int nBlocks = 7;
+  // if (nelems <= 4096) {
+  //   nBlocks = 7;
+  // } else if (nelems <= 32768) {
+  //   nBlocks = 14;
+  // } else if (nelems >= 2097152) {
+  //   nBlocks = 35;
+  // }
+  broadcast6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, (void*)scratch, (void*)resultBuff, smChannels,
+                                                         channelOutOffset, rank, worldSize, root, nRanksPerNode,
+                                                         nelems * sizeof(T) / sizeof(int));
+  return cudaGetLastError();
+}
+
+#endif  // BROADCAST_HPP_
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index fe240de7..06ab189d 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -15,6 +15,7 @@
 
 #include "allgather.hpp"
 #include "allreduce.hpp"
+#include "broadcast.hpp"
 #include "nccl.h"
 
 #define NCCL_API extern "C" __attribute__((visibility("default")))
@@ -514,14 +515,100 @@ NCCL_API ncclResult_t ncclReduce(const void*, void*, size_t, ncclDataType_t, ncc
   return ncclInternalError;
 }
 
-NCCL_API ncclResult_t ncclBcast(void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
-  // TODO: implement this function
-  return ncclInternalError;
+NCCL_API ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, ncclComm_t comm,
+                                cudaStream_t stream) {
+  return ncclBroadcast(buff, buff, count, datatype, root, comm, stream);
 }
 
-NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
-  // TODO: implement this function
-  return ncclInternalError;
+NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff, size_t sendcount,
+                                            ncclDataType_t datatype, int root, ncclComm_t comm, cudaStream_t stream) {
+  size_t bytes = sendcount * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
+
+  // Declaring variables
+  size_t recvBytes;
+  CUdeviceptr recvBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
+  // size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
+  size_t offsetOut = 0;
+  // channelKey recvKey{(void*)recvBasePtr, recvBytes};
+  channelKey recvKey{(void*)0x0, 0};  // Just create the channel once.
+  int rank = comm->comm->bootstrap()->getRank();
+  int nRank = comm->comm->bootstrap()->getNranks();
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
+
+  auto it = comm->channelOutInfos.find(recvKey);
+  if (it == comm->channelOutInfos.end()) {
+    // std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
+    //     comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
+    // std::vector<mscclpp::SmChannel> channels =
+    //     setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
+    std::vector<mscclpp::SmChannel> channels =
+        setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)recvBasePtr));
+    std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
+    std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
+                   [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
+    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+    it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
+  }
+
+  smChannels = it->second.smChannelDeviceHandles.get();
+  if ((char*)sendbuff == (char*)recvbuff) {
+    CUDACHECK(broadcast<false>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, offsetOut,
+                               rank, NRANKS_PER_NODE, root, nRank, bytes / sizeof(int), stream));
+  } else {
+    CUDACHECK(broadcast<true>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, offsetOut,
+                              rank, NRANKS_PER_NODE, root, nRank, bytes / sizeof(int), stream));
+  }
+
+  return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
+                                    int root, ncclComm_t comm, cudaStream_t stream) {
+  size_t bytes = count * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
+
+  int rank = comm->comm->bootstrap()->getRank();
+  int nRank = comm->comm->bootstrap()->getNranks();
+
+  std::vector<executionPlanInstance>& plans = comm->executionPlans["broadcast"];
+  std::shared_ptr<mscclpp::ExecutionPlan> plan;
+  void* basePtr = (char*)sendbuff;
+  bool inPlace = basePtr == recvbuff;
+  const size_t totalBytes = bytes;
+  for (const auto& p : plans) {
+    if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
+      plan = p.plan;
+      break;
+    }
+  }
+
+  if (plan == nullptr) return ncclBroadcastFallback(sendbuff, recvbuff, count, datatype, root, comm, stream);
+
+  switch (datatype) {
+    case ncclFloat16:
+      comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
+                              stream);
+      break;
+    case ncclFloat32:
+      comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32, *plan,
+                              stream);
+      break;
+    case ncclBfloat16:
+      comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
+                              mscclpp::DataType::BFLOAT16, *plan, stream);
+      break;
+    case ncclInt32:
+    case ncclUint32:
+      comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
+                              stream);
+      break;
+    default:
+      return ncclInvalidArgument;
+  }
+
+  return ncclSuccess;
 }
 
 NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
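
For context, the ncclBroadcast/ncclBcast entry points added above keep the standard NCCL signatures, so existing NCCL host code can exercise them unchanged. The snippet below is a minimal, hypothetical per-rank sketch (not part of this change): it assumes one process per GPU on a single node, and that `rank`, `nRanks`, and the `ncclUniqueId` created on rank 0 are distributed out of band (e.g. over MPI) before this function runs. Names such as runBroadcast and buf are illustrative only.

#include <cuda_runtime.h>
#include <nccl.h>

// Hypothetical usage sketch: broadcasts `count` floats in place from rank 0.
void runBroadcast(int rank, int nRanks, ncclUniqueId id, size_t count) {
  cudaSetDevice(rank);                        // one GPU per rank on a single node (assumption)
  ncclComm_t comm;
  ncclCommInitRank(&comm, nRanks, id, rank);  // collective communicator setup

  float* buf;
  cudaMalloc(&buf, count * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // In-place broadcast from rank 0. With this change, the call goes through the
  // execution-plan path when a matching "broadcast" plan is registered, and
  // otherwise through ncclBroadcastFallback's scratch-buffer kernel.
  ncclBroadcast(buf, buf, count, ncclFloat32, /*root=*/0, comm, stream);
  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  cudaFree(buf);
  ncclCommDestroy(comm);
}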