Intra-node AllGather
In this tutorial, you will write a distributed AllGather kernel using Triton-distributed.
In doing so, you will learn about:
Writing the AllGather kernel with symmetric pointers directly.
Writing the AllGather kernel with NVSHMEM device functions.
# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/02-intra-node-allgather.py
Kernel
There are several communication methods available for intra-node communication: it can be achieved directly through Memory Copy interfaces (utilizing the copy engine), via kernel ld/st operations (using SMs), or with NVSHMEM primitives (also using SMs). We recommend using either the Memory Copy interface or NVSHMEM primitives. In terms of performance, both methods can achieve comparable results. The key difference is that Memory Copy does not occupy SM resources.
Let’s introduce the Memory Copy interface first. The only difference from a regular PyTorch program is the remote_tensor_buffers parameter. This parameter is a list of Tensors, where each element corresponds to the Tensor at the respective rank position. This parameter is obtained through NVSHMEM’s host interface:
from triton_dist.utils import nvshmem_create_tensors
# Create symmetric tensors accessible from all ranks
symm_ag_buffers = nvshmem_create_tensors((M, N), dtype, rank, LOCAL_WORLD_SIZE)
import os
from typing import List
import nvshmem.core
import torch
from triton_dist.utils import cuda
import triton_dist
import triton.language as tl
from triton_dist.language.extra import libshmem_device
from triton_dist.utils import (
CUDA_CHECK,
dist_print,
initialize_distributed,
nvshmem_barrier_all_on_stream,
NVSHMEM_SIGNAL_DTYPE,
nvshmem_create_tensors,
nvshmem_free_tensor_sync,
)
def cp_engine_producer_all_gather_full_mesh_pull(
rank,
num_ranks,
local_tensor: torch.Tensor,
remote_tensor_buffers: List[torch.Tensor],
ag_stream: torch.cuda.Stream,
barrier_buffers: List[torch.Tensor],
):
M_per_rank, N = local_tensor.shape
rank_orders = [(rank + i) % num_ranks for i in range(num_ranks)]
with torch.cuda.stream(ag_stream):
for src_rank in rank_orders:
if src_rank == rank:
continue
# peer: src_rank, offset src_rank[src_rank] -> rank[src_rank]
dst = remote_tensor_buffers[rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
src = remote_tensor_buffers[src_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
dst.copy_(src)
(err, ) = cuda.cuStreamWriteValue32(
ag_stream.cuda_stream,
barrier_buffers[rank][src_rank].data_ptr(),
1,
cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
)
CUDA_CHECK(err)
For kernels with NVSHMEM primitives, we use the triton_dist.jit decorator. In this example, we use 8 SMs (DISPATCH_BLOCK_NUM = 8) to perform AllGather. Each block is responsible for sending its local data to the other 7 GPUs.
@triton_dist.jit
def nvshmem_device_producer_all_gather_2d_put_block_kernel(
remote_tensor_ptr,
signal_buffer_ptr,
elem_per_rank,
size_per_elem,
signal_target,
local_rank,
world_size,
DISPATCH_BLOCK_NUM: tl.constexpr,
):
pid = tl.program_id(axis=0)
if pid < DISPATCH_BLOCK_NUM: # intra dispatch block
peer = (local_rank + pid + 1) % world_size
segment = local_rank
libshmem_device.putmem_signal_block( # send the segment to the peer and notify the segment is ready
remote_tensor_ptr + segment * elem_per_rank,
remote_tensor_ptr + segment * elem_per_rank,
elem_per_rank * size_per_elem,
signal_buffer_ptr + segment,
signal_target,
libshmem_device.NVSHMEM_SIGNAL_SET,
peer,
)
Test the Correctness
if __name__ == "__main__":
TP_GROUP = initialize_distributed()
rank = TP_GROUP.rank()
num_ranks = TP_GROUP.size()
LOCAL_WORLD_SIZE = int(os.getenv("LOCAL_WORLD_SIZE"))
assert num_ranks == LOCAL_WORLD_SIZE, "This tutorial is designed for intra-node"
M = 8192
N = 12288
M_per_rank = M // num_ranks
dtype = torch.float16
local_data = torch.randn([M_per_rank, N], dtype=dtype, device="cuda")
# Create symmetric tensors using the new API
symm_ag_buffers = nvshmem_create_tensors((M, N), dtype, rank, LOCAL_WORLD_SIZE)
symm_ag_buffer = symm_ag_buffers[rank]
symm_signals = nvshmem_create_tensors((num_ranks, ), NVSHMEM_SIGNAL_DTYPE, rank, LOCAL_WORLD_SIZE)
symm_signal = symm_signals[rank]
# Calculate golden
golden = torch.empty([M, N], dtype=dtype, device="cuda")
torch.distributed.all_gather_into_tensor(golden, local_data, group=TP_GROUP)
#####################
# Copy Engine
symm_ag_buffer.fill_(-1) # reset buffer
symm_ag_buffer[
rank * M_per_rank:(rank + 1) * M_per_rank,
].copy_(local_data) # copy local data to symmetric memory for communication
symm_signal.fill_(0) # The initial value of signal should be 0s
# We need barrier all to make sure the above initialization visible to other ranks
nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
cp_engine_producer_all_gather_full_mesh_pull(
rank, num_ranks, local_data, symm_ag_buffers, torch.cuda.current_stream(),
symm_signals) # Here we use current stream for allgather, we can pass any other stream for comm-comp fusion.
# Check results. Pull mode doesn't need sync after communication
dist_print(f"Rank {rank} CpEngine Result:\n", symm_ag_buffer, need_sync=True, allowed_ranks="all")
dist_print(f"Rank {rank} CpEngine Signal:\n", symm_signal, need_sync=True, allowed_ranks="all")
assert torch.allclose(golden, symm_ag_buffer, atol=1e-5, rtol=1e-5)
dist_print(f"Rank {rank}", "Pass!✅", need_sync=True, allowed_ranks="all")
#####################
# NVSHMEM Primitives
symm_ag_buffer.fill_(-1) # reset buffer
symm_ag_buffer[
rank * M_per_rank:(rank + 1) * M_per_rank,
].copy_(local_data) # copy local data to symmetric memory for communication
symm_signal.fill_(0) # The initial value of signal should be 0s
# We need barrier all to make sure the above initialization visible to other ranks
nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
grid = lambda META: (int(num_ranks), )
nvshmem_device_producer_all_gather_2d_put_block_kernel[grid](
symm_ag_buffer, symm_signal, M_per_rank * N, # No. of elems of local data
local_data.element_size(), # element size
1, # signal target, can be any other value in practice
rank, num_ranks, num_ranks)
# Need to sync all to guarantee the completion of communication
nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
# Check results. Pull mode doesn't need sync after communication
dist_print(f"Rank {rank} NVSHMEM Result:\n", symm_ag_buffer, need_sync=True, allowed_ranks="all")
dist_print(f"Rank {rank} NVSHMEM Signal:\n", symm_signal, need_sync=True, allowed_ranks="all")
assert torch.allclose(golden, symm_ag_buffer, atol=1e-5, rtol=1e-5)
dist_print(f"Rank {rank}", "Pass!✅", need_sync=True, allowed_ranks="all")
# Clean up symmetric memory
nvshmem_free_tensor_sync(symm_ag_buffer)
nvshmem_free_tensor_sync(symm_signal)
nvshmem.core.finalize()
torch.distributed.destroy_process_group()