Inter-node ReduceScatter
In this tutorial, you will write a multi-node reduce-scatter operation.
# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/06-inter-node-reduce-scatter.py
Kernel
@triton_dist.jit
def kernel_inter_node_p2p_for_same_local_rank(offset, local_world_size, M_per_rank, N, input, # [M, N]
output, # [M, N]
):
"""
This kernel performs P2P communication, sending corresponding data
to the GPU with the same local rank on node `cur_node_id + offset + 1`.
"""
rank = dl.rank()
world_size = dl.num_ranks()
node_id = rank // local_world_size
nnodes = world_size // local_world_size
local_rank = rank % local_world_size
nelem_per_rank = M_per_rank * N
remote_node_id = (offset + 1 + node_id) % nnodes
remote_rank = local_rank + remote_node_id * local_world_size
elem_size = tl.constexpr(input.dtype.element_ty.primitive_bitwidth) // 8
libshmem_device.putmem_block(
output + node_id * nelem_per_rank,
input + remote_node_id * nelem_per_rank,
nelem_per_rank * elem_size,
remote_rank,
)
def intra_node_scatter(input_intra_node, scatter_bufs_intra_node: List[torch.Tensor],
scatter_signal_buf_intra_node: torch.Tensor, local_rank, stream, overlap_with_gemm=True):
M, N = input_intra_node.shape
local_world_size = len(scatter_bufs_intra_node)
M_per_rank = M // local_world_size
# send input_intra_node[remote_rank * M_per_rank : (remote_rank + 1) * M_per_rank] on the current rank to
# input_intra_node[rank * M_per_rank : (rank + 1)] on the remote rank.
with torch.cuda.stream(stream):
for i in range(0, local_world_size):
# Rank level swizzle: Each rank perform scatter start from the next rank of the current.
# In this way, the send/recv communication volume of each rank is balanced.
remote_local_rank = (local_rank + i + 1) % local_world_size
if overlap_with_gemm:
# wait for the corresponding barrier to be set to 1, indicate the corresponding GEMM tile computation is complete.
wait_eq(scatter_signal_buf_intra_node[remote_local_rank].data_ptr(), 1, # signal
stream, True)
remote_buf = scatter_bufs_intra_node[remote_local_rank][local_rank * M_per_rank:(local_rank + 1) *
M_per_rank, :]
local_buf = input_intra_node[remote_local_rank * M_per_rank:(remote_local_rank + 1) * M_per_rank, :]
# use copy engine to perform scatter(torch will use `cudamemcpy` to copy continuous data)
remote_buf.copy_(local_buf)
We perform node-level swizzle in the reduce-scatter kernel. Each node perform intra-node reduce-scatter start from the next node of the current. In this way, the P2P communication volume of each node is balanced, and the last intra-node reduce-scatter operation is performed on the current node without the need for inter-node communication.
For time start from 0 to nnode - 1, the communication order between nodes:
time 0: 0->1, 1->2, 2->3, 3->0
time 1: 0->2, 1->3, 2->0, 3->1
time 2: 0->3, 1->0, 2->1, 3->2
time 3: 0->0, 1->1, 2->2, 3->3
def reducer_scatter_for_each_node(input, stream, ctx: ReduceScatter2DContext):
world_size = ctx.world_size
local_world_size = ctx.local_world_size
local_rank = ctx.local_rank
reduction_stream = ctx.reduction_stream
num_reduction_sms = ctx.num_reduction_sms
M, N = input.shape
M_per_rank = M // world_size
M_per_node = M_per_rank * local_world_size
nnodes = ctx.nnodes
node_id = ctx.node_id
rs_per_node_buf = ctx.symm_rs_per_node_buf
p2p_buf = ctx.symm_p2p_buf
# For `target_node_id` in the range [0, nnodes - 1], perform an intra-node reduce-scatter on the
# input[target_node_id * M_per_node: (target_node_id + 1) * M_per_node]. Then send the result via
# P2P to the same local rank on the node 'target_node_id'.
with torch.cuda.stream(stream):
for n in range(0, nnodes):
cur_node_id = (node_id + n + 1) % nnodes
input_intra_node = input[cur_node_id * M_per_node:(cur_node_id + 1) * M_per_node]
scatter_bufs_intra_node, scatter_signal_buf_intra_node = ctx.get_scatter_bufs_and_signal_for_each_node(
input, cur_node_id)
# step1: intra node reduce-scatter
# step1-1: intra node scatter, the corresponding data has been computed by GEMM.
intra_node_scatter(input_intra_node, scatter_bufs_intra_node, scatter_signal_buf_intra_node, local_rank,
stream, overlap_with_gemm=ctx.overlap_with_gemm)
rs_buf_cur_node = rs_per_node_buf[M_per_rank * cur_node_id:(cur_node_id + 1) * M_per_rank]
# step1-2: perform barrier_all, wait for all peers within the node to complete the scatter operation
barrier_all_on_stream(ctx.barrier, stream)
reduction_stream.wait_stream(stream)
with torch.cuda.stream(reduction_stream):
# step1-3: perform reduction operation to get the result of the intra-node reduce-scatter.
ring_reduce(scatter_bufs_intra_node[local_rank], rs_buf_cur_node, local_rank, local_world_size,
num_sms=-1 if n == nnodes - 1 else num_reduction_sms)
# step2: inter node p2p, send result to the same local rank on the node `(n + 1 + node_id) % nnodes`.
if nnodes > 1:
if n == nnodes - 1:
p2p_buf[M_per_rank * node_id:M_per_rank * (node_id + 1)].copy_(
rs_per_node_buf[M_per_rank * node_id:M_per_rank * (node_id + 1)])
else:
grid = lambda META: (ctx.num_p2p_sms, )
kernel_inter_node_p2p_for_same_local_rank[grid](
n,
local_world_size,
M_per_rank,
N,
rs_per_node_buf,
p2p_buf,
num_warps=16,
)
stream.wait_stream(reduction_stream)
if nnodes == 1:
return rs_per_node_buf[:M_per_rank * nnodes]
return p2p_buf[:M_per_rank * nnodes]
A hierarchical reduce-scatter implementation that overlaps the intra-node scatter with the local reduce and the inter-node p2p(after reduce). It also provides a rank-wise signal and supports overlap with gemm.
def reduce_scatter_multi_node(input, stream, ctx: ReduceScatter2DContext):
M, N = input.shape
M_per_rank = M // ctx.world_size
ctx.p2p_stream.wait_stream(stream)
"""
Step 1: Leveraging the characteristics of reduce-scatter, we first partition the input data according to the target nodes for communication.
For the data send to each node, we perform an intra-node reduce-scatter operation within the current node.
Finally, we use P2P communication to send the data to the same local rank on the target node.
This can reduce the inter-node communication volume by a factor of local_world_size.
"""
rs_resutl_per_node = reducer_scatter_for_each_node(input, stream, ctx)
nvshmem_barrier_all_on_stream(stream)
output = torch.empty((M_per_rank, N), dtype=input.dtype, device=input.device)
"""
Step 2: After receiving data sent via P2P from all nodes, perform a reduction to get the final result.
"""
with torch.cuda.stream(stream):
ring_reduce(rs_resutl_per_node, output, ctx.node_id, ctx.nnodes)
return output
def reduce_scatter_2d_op(input, ctx: ReduceScatter2DContext):
# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device="cuda", dtype=torch.int8)
triton.set_allocator(alloc_fn)
reduction_stream = ctx.reduction_stream
M, N = input.shape
assert input.dtype == ctx.dtype
assert ctx.max_M >= M and ctx.N == N
assert M % ctx.world_size == 0
current_stream = torch.cuda.current_stream()
reduction_stream.wait_stream(current_stream)
# Wait for the completion of the previous iteration.
nvshmem_barrier_all_on_stream(current_stream)
# perform reduce-scatter
output = reduce_scatter_multi_node(input, current_stream, ctx)
# Reset the barriers for the next iteration.
ctx.reset_barriers()
return output
Benchmark
def torch_rs(
input: torch.Tensor, # [M, N]
TP_GROUP,
):
M, N = input.shape
rs_output = torch.empty((M // WORLD_SIZE, N), dtype=input.dtype, device=input.device)
torch.distributed.reduce_scatter_tensor(rs_output, input, group=TP_GROUP)
return rs_output
if __name__ == "__main__":
# init
RANK = int(os.environ.get("RANK", 0))
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
TP_GROUP = initialize_distributed()
torch.cuda.synchronize()
output_dtype = torch.bfloat16
M, N = 8192, 16384
rs_ctx = create_reduce_scater_2d_ctx(M, N, RANK, WORLD_SIZE, LOCAL_WORLD_SIZE, output_dtype,
overlap_with_gemm=False)
# gen input
input = torch.rand((M, N), dtype=output_dtype).cuda()
# torch impl
torch_output = torch_rs(input, TP_GROUP)
nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
# dist triton impl
dist_triton_output = reduce_scatter_2d_op(input, rs_ctx)
nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
torch.cuda.synchronize()
# check
atol, rtol = 6e-2, 6e-2
torch.testing.assert_close(torch_output, dist_triton_output, atol=atol, rtol=rtol)
torch.cuda.synchronize()
print(f"RANK {RANK}: pass!")
rs_ctx.finalize()
nvshmem.core.finalize()
torch.distributed.destroy_process_group()