ncclBroadcast using scratch buffer and option to use an executor on top of the main branch.
SreevatsaAnantharamu committed Dec 18, 2024
1 parent fcb2e46 commit 64de849
Showing 2 changed files with 264 additions and 6 deletions.
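
For orientation, a minimal caller-side sketch (not part of this commit) of the NCCL API that this change implements; creation of comm via ncclCommInitRank and stream via cudaStreamCreate is assumed and omitted, and the buffer size and datatype are illustrative.

  // Hypothetical usage sketch; count, comm, and stream are assumed to exist.
  size_t count = 1 << 20;  // 1M floats per rank (illustrative).
  float* buf = nullptr;
  cudaMalloc(&buf, count * sizeof(float));
  // In-place call: ncclBcast forwards to ncclBroadcast with sendbuff == recvbuff, which uses a
  // matching "broadcast" execution plan if one is loaded, otherwise the scratch-buffer kernel.
  ncclBcast(buf, count, ncclFloat32, /*root=*/0, comm, stream);
  cudaStreamSynchronize(stream);
  cudaFree(buf);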
171 changes: 171 additions & 0 deletions apps/nccl/src/broadcast.hpp
@@ -0,0 +1,171 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef BROADCAST_HPP_
#define BROADCAST_HPP_

#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>

#include "common.hpp"

template <bool IsOutOfPlace>
__global__ void __launch_bounds__(1024, 1)
    broadcast6(void* sendbuff, void* scratchbuff, void* recvbuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
               size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, size_t root,
               size_t nRanksPerNode, size_t nelemsPerGPU) {
  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
  const size_t lid = tid % WARP_SIZE;
  const size_t wid = tid / WARP_SIZE;

  const size_t nThread = blockDim.x * gridDim.x;
  const size_t nWarp = nThread / WARP_SIZE;
  const size_t nPeer = nRanksPerNode - 1;
  const size_t chanOffset = nPeer * blockIdx.x;

  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> smChans[NRANKS_PER_NODE - 1];
  if (threadIdx.x < nPeer) {
    smChans[threadIdx.x] = smChannels[chanOffset + threadIdx.x];
    smChans[threadIdx.x].relaxedSignal();
    smChans[threadIdx.x].wait();
  }
  __syncthreads();

  const size_t peerRootIdx = (root == rank) ? nPeer : ((root < rank) ? root : (root - 1));
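  // peerRootIdx is the position of the root's channel within this rank's nPeer channels
  // (assuming channels are ordered by peer rank with self skipped); for the root itself it is
  // the sentinel value nPeer and is never used as an index.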

  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
  const size_t bytes = bytesPerGPU;
  size_t unitBytesPerThread;
  if (bytes * nPeer >= nThread * 64) {
    unitBytesPerThread = 64;
  } else {
    unitBytesPerThread = 16;
  }
  const size_t unitBytesPerBlock = unitBytesPerThread * blockDim.x;
  const size_t unitBytes = unitBytesPerBlock * gridDim.x;
  const size_t nLoop = bytes / unitBytes;
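  // Example: with the 7-block x 1024-thread launch used below and unitBytesPerThread = 64,
  // unitBytesPerBlock is 64 KiB and each loop iteration moves unitBytes = 448 KiB.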

  const size_t maxScratchSizeToUse = (SCRATCH_SIZE - unitBytes);
  const size_t nLoopToSync = (maxScratchSizeToUse / unitBytes) + 1;
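  // The root writes iteration i at (offset + scratchSub) in each peer's scratch buffer, so after
  // nLoopToSync iterations the writes would run past SCRATCH_SIZE; at that point all ranks sync
  // and scratchSub rewinds the write position to the start of the scratch buffer. Reserving
  // unitBytes of headroom keeps the remainder chunk in range as well.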

  size_t scratchSub = 0;

  // First loop will always fit the scratch size.
  if (nLoop > 0) {
    // First loop unrolling
    const size_t offset = blockIdx.x * unitBytesPerBlock;
    if (rank == root) {
      char* send_ = reinterpret_cast<char*>(sendbuff);
      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
        smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
        __syncthreads();
        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
      }
      if constexpr (IsOutOfPlace) {
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
      }

    } else {  // rank != root.
      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
      __syncthreads();
      char* recv_ = reinterpret_cast<char*>(recvbuff);
      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x,
                                           blockDim.x);
    }
  }

  for (size_t i = 1; i < nLoop; ++i) {
    const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
    if (i % nLoopToSync == 0) {  // Sync to reuse scratch buff
      scratchSub = -i * unitBytes;
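      // scratchSub is unsigned, so adding -i * unitBytes (modular arithmetic) to offset cancels
      // the i * unitBytes term: writes for iterations >= i start again from the beginning of the
      // scratch buffer.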
      deviceSyncer.sync(gridDim.x);
      if (threadIdx.x < nPeer) {
        smChans[threadIdx.x].relaxedSignal();
        smChans[threadIdx.x].wait();
      }
    }
    if (rank == root) {
      char* send_ = reinterpret_cast<char*>(sendbuff);
      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
        smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
                                         blockDim.x);
        __syncthreads();
        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
      }
      if constexpr (IsOutOfPlace) {
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
      }
    } else {  // rank != root.
      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
      __syncthreads();
      char* recv_ = reinterpret_cast<char*>(recvbuff);
      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock,
                                           threadIdx.x, blockDim.x);
    }
  }

  // Remainder loop will also fit the scratch buff since we subtract unitBytes from SCRATCH_SIZE.
  if (bytes % unitBytes > 0) {  // remainder.
    const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
    const size_t remainBytes = (offset < bytes) ? (bytes - offset) : 0;
    if (remainBytes > 0) {
      if (rank == root) {
        char* send_ = reinterpret_cast<char*>(sendbuff);
        for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
          char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
          smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x,
                                          blockDim.x);
          __syncthreads();
          if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
        }
        if constexpr (IsOutOfPlace) {
          char* recv_ = reinterpret_cast<char*>(recvbuff);
          smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
        }
      } else {  // rank != root.
        if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
        __syncthreads();
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
        smChans[peerRootIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
                                            blockDim.x);
      }
    }  // remainBytes > 0.
  }

  deviceSyncer.sync(gridDim.x);

  if (threadIdx.x < nPeer) {
    smChans[threadIdx.x].relaxedSignal();
    smChans[threadIdx.x].wait();
  }
}

template <bool IsOutOfPlace, typename T>
cudaError_t broadcast(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
                      size_t channelOutOffset, int rank, int nRanksPerNode, int root, int worldSize, size_t nelems,
                      cudaStream_t stream) {
  int nBlocks = 7;
  // if (nelems <= 4096) {
  //   nBlocks = 7;
  // } else if (nelems <= 32768) {
  //   nBlocks = 14;
  // } else if (nelems >= 2097152) {
  //   nBlocks = 35;
  // }
  broadcast6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, (void*)scratch, (void*)resultBuff, smChannels,
                                                         channelOutOffset, rank, worldSize, root, nRanksPerNode,
                                                         nelems * sizeof(T) / sizeof(int));
  return cudaGetLastError();
}

#endif // BROADCAST_HPP_
99 changes: 93 additions & 6 deletions apps/nccl/src/nccl.cu
@@ -15,6 +15,7 @@

#include "allgather.hpp"
#include "allreduce.hpp"
#include "broadcast.hpp"
#include "nccl.h"

#define NCCL_API extern "C" __attribute__((visibility("default")))
@@ -514,14 +515,100 @@ NCCL_API ncclResult_t ncclReduce(const void*, void*, size_t, ncclDataType_t, ncc
  return ncclInternalError;
}

NCCL_API ncclResult_t ncclBcast(void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
  // TODO: implement this function
  return ncclInternalError;
NCCL_API ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, ncclComm_t comm,
                                cudaStream_t stream) {
  return ncclBroadcast(buff, buff, count, datatype, root, comm, stream);
}

NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
  // TODO: implement this function
  return ncclInternalError;
NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff, size_t sendcount,
                                            ncclDataType_t datatype, int root, ncclComm_t comm, cudaStream_t stream) {
  size_t bytes = sendcount * ncclTypeSize(datatype);
  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

  // Declaring variables
  size_t recvBytes;
  CUdeviceptr recvBasePtr;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
  // size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
  size_t offsetOut = 0;
  // channelKey recvKey{(void*)recvBasePtr, recvBytes};
  channelKey recvKey{(void*)0x0, 0};  // Just create the channel once.
  int rank = comm->comm->bootstrap()->getRank();
  int nRank = comm->comm->bootstrap()->getNranks();
  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;

  auto it = comm->channelOutInfos.find(recvKey);
  if (it == comm->channelOutInfos.end()) {
    // std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
    //     comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
    // std::vector<mscclpp::SmChannel> channels =
    //     setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
    std::vector<mscclpp::SmChannel> channels =
        setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)recvBasePtr));
    std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
    std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
                   [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
    it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
  }

  smChannels = it->second.smChannelDeviceHandles.get();
  if ((char*)sendbuff == (char*)recvbuff) {
    CUDACHECK(broadcast<false>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, offsetOut,
                               rank, NRANKS_PER_NODE, root, nRank, bytes / sizeof(int), stream));
  } else {
    CUDACHECK(broadcast<true>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, offsetOut,
                              rank, NRANKS_PER_NODE, root, nRank, bytes / sizeof(int), stream));
  }

  return ncclSuccess;
}

NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
                                    int root, ncclComm_t comm, cudaStream_t stream) {
  size_t bytes = count * ncclTypeSize(datatype);
  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

  int rank = comm->comm->bootstrap()->getRank();
  int nRank = comm->comm->bootstrap()->getNranks();

  std::vector<executionPlanInstance>& plans = comm->executionPlans["broadcast"];
  std::shared_ptr<mscclpp::ExecutionPlan> plan;
  void* basePtr = (char*)sendbuff;
  bool inPlace = basePtr == recvbuff;
  const size_t totalBytes = bytes;
  for (const auto& p : plans) {
    if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
      plan = p.plan;
      break;
    }
  }

  if (plan == nullptr) return ncclBroadcastFallback(sendbuff, recvbuff, count, datatype, root, comm, stream);

  switch (datatype) {
    case ncclFloat16:
      comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
                              stream);
      break;
    case ncclFloat32:
      comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32, *plan,
                              stream);
      break;
    case ncclBfloat16:
      comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
                              mscclpp::DataType::BFLOAT16, *plan, stream);
      break;
    case ncclInt32:
    case ncclUint32:
      comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
                              stream);
      break;
    default:
      return ncclInvalidArgument;
  }

  return ncclSuccess;
}

NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,