Intra-node ReduceScatter
In this tutorial, you will write an intra-node reduce-scatter operation.
In doing so, you will learn about:
How to use the copy engine to communicate data within a node.
How to write an intra-node reduce-scatter kernel.
# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/05-intra-node-reduce-scatter.py
Reduce Kernel
The difference between this kernel and a normal reduction operation lies in the accumulation order. A normal reduction starts accumulating from the 0th split. This kernel instead starts from the split at index (begin_idx + 1) % num_splits and wraps around until all num_splits splits have been accumulated.
@triton.jit
def kernel_ring_reduce(
    c_ptr,  # [M_per_rank * num_splits, N]
    out_ptr,  # [M_per_rank, N]
    # shape of matrix
    M_per_rank,
    N,
    begin_idx,
    num_splits: tl.constexpr,
    # reduce tile shape
    BLOCK_SIZE_M: tl.constexpr = 256,
    BLOCK_SIZE_N: tl.constexpr = 64,
):
    """Sum `num_splits` row-splits of `c_ptr` tile by tile into `out_ptr`.

    `c_ptr` is addressed as `num_splits` stacked splits of `M_per_rank` rows
    and `N` columns each. For every output tile the splits are visited in the
    order `(begin_idx + 1) % num_splits`, `(begin_idx + 2) % num_splits`, ...
    until all `num_splits` splits have been added together.

    Each program handles tiles `pid, pid + num_programs, ...`, so the grid may
    be smaller than the number of tiles.
    """
    src_desc = tl.make_tensor_descriptor(
        c_ptr,
        shape=[M_per_rank * num_splits, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )
    dst_desc = tl.make_tensor_descriptor(
        out_ptr,
        shape=[M_per_rank, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )

    pid = tl.program_id(axis=0)
    num_pid = tl.num_programs(axis=0)
    tiles_along_m = tl.cdiv(M_per_rank, BLOCK_SIZE_M)
    tiles_along_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_total = tiles_along_m * tiles_along_n

    for tile_id in range(pid, tiles_total, num_pid):
        # Row/column offset of this tile inside a single split.
        row_off = (tile_id // tiles_along_n) * BLOCK_SIZE_M
        col_off = (tile_id % tiles_along_n) * BLOCK_SIZE_N

        # Seed the accumulator with the first split in ring order instead of zeros.
        first_split = (begin_idx + 1) % num_splits
        accum = src_desc.load([row_off + first_split * M_per_rank, col_off])

        # Add the remaining num_splits - 1 splits, wrapping around the ring.
        for step in range(1, num_splits):
            split = (step + begin_idx + 1) % num_splits
            accum += src_desc.load([row_off + split * M_per_rank, col_off])

        dst_desc.store([row_off, col_off], accum)
def ring_reduce(
    input,  # [M_per_node, N]
    output,  # [M_per_rank, N]
    begin_idx,
    num_splits,
    stream,
    num_sms=-1,
):
    """Launch `kernel_ring_reduce` on `stream` and return `output`.

    `input` is treated as `num_splits` equal row-splits that are summed into
    `output`; `begin_idx` selects where the ring-ordered accumulation starts.

    With `num_sms == -1` one program is launched per 256x64 tile (4 warps).
    Otherwise the grid is capped at `num_sms` programs and 256x128 tiles with
    8 warps are used.

    Raises:
        AssertionError: if the rows of `input` are not divisible by
            `num_splits` or `output` does not have `rows // num_splits` rows.
    """

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    total_M, N = input.shape
    M_per_split = total_M // num_splits
    assert output.shape[0] == M_per_split and total_M % num_splits == 0

    def tile_count(META):
        # Number of output tiles for the block shape the kernel is launched with.
        return triton.cdiv(M_per_split, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])

    if num_sms == -1:
        launch_cfg = dict(BLOCK_SIZE_M=256, BLOCK_SIZE_N=64, num_warps=4)

        def grid(META):
            return (tile_count(META), )
    else:
        launch_cfg = dict(BLOCK_SIZE_M=256, BLOCK_SIZE_N=128, num_warps=8)

        def grid(META):
            # Never launch more programs than requested; the kernel strides over tiles.
            return (min(tile_count(META), num_sms), )

    with torch.cuda.stream(stream):
        kernel_ring_reduce[grid](
            input,
            output,
            M_per_split,
            N,
            begin_idx,
            num_splits,
            **launch_cfg,
        )
    return output
Scatter Kernel
We perform a rank-level swizzle in the scatter kernel: each rank starts its scatter from the rank after itself, so that the send/recv communication volume of each rank is balanced. For time steps 0 to local_world_size - 1 (shown here for local_world_size = 4), the communication order between ranks is:
time 0: 0->1, 1->2, 2->3, 3->0
time 1: 0->2, 1->3, 2->0, 3->1
time 2: 0->3, 1->0, 2->1, 3->2
time 3: 0->0, 1->1, 2->2, 3->3
def intra_node_scatter(input_intra_node, scatter_bufs_intra_node: List[torch.Tensor], local_rank, stream):
    """Scatter the row-splits of `input_intra_node` into every rank's buffer.

    The input is split along dim 0 into `len(scatter_bufs_intra_node)` equal
    chunks. Chunk `r` is copied into rows
    `[local_rank * M_per_rank, (local_rank + 1) * M_per_rank)` of
    `scatter_bufs_intra_node[r]`. Destinations are visited starting from
    `local_rank + 1` and wrapping around, so the last copy is the local one.

    Args:
        input_intra_node: 2-D tensor of shape [M, N].
        scatter_bufs_intra_node: one destination buffer per local rank.
        local_rank: index of the calling rank within the node.
        stream: stream passed to `torch.cuda.stream` for the copies.

    Raises:
        ValueError: if M is not divisible by the number of buffers; the tail
            rows would otherwise be dropped silently.
    """
    M, _ = input_intra_node.shape
    local_world_size = len(scatter_bufs_intra_node)
    if M % local_world_size != 0:
        raise ValueError(f"input rows ({M}) must be divisible by the number of ranks ({local_world_size})")
    M_per_rank = M // local_world_size

    # Rows owned by this rank inside every destination buffer.
    dst_rows = slice(local_rank * M_per_rank, (local_rank + 1) * M_per_rank)

    # send input_intra_node[remote_rank * M_per_rank : (remote_rank + 1) * M_per_rank] on the current rank to
    # scatter_bufs_intra_node[remote_rank][rank * M_per_rank : (rank + 1) * M_per_rank] on the remote rank.
    with torch.cuda.stream(stream):
        for i in range(local_world_size):
            remote_local_rank = (local_rank + i + 1) % local_world_size
            src_rows = slice(remote_local_rank * M_per_rank, (remote_local_rank + 1) * M_per_rank)
            remote_buf = scatter_bufs_intra_node[remote_local_rank][dst_rows, :]
            local_buf = input_intra_node[src_rows, :]
            # use copy engine to perform scatter(torch will use `cudamemcpy` to copy continuous data)
            remote_buf.copy_(local_buf)
Reduce-Scatter
@p2p_native_atomic_required
def reducer_scatter_intra_node(input, scatter_bufs, sync_buf, local_rank, local_world_size):
    """Reduce-scatter `input` across the ranks of one node.

    Args:
        input: 2-D tensor of shape [M, N] held by this rank.
        scatter_bufs: per-rank buffers handed to `intra_node_scatter`; the
            entry at `local_rank` is the one reduced locally.
        sync_buf: buffer passed to the intra-node barrier kernel.
        local_rank: index of this rank within the node.
        local_world_size: number of ranks on the node.

    Returns:
        A new tensor of shape [M // local_world_size, N] with the same dtype
        and device as `input`.
    """
    stream = torch.cuda.current_stream()
    M, N = input.shape
    M_per_rank = M // local_world_size
    # Use the input's dtype; relying on a module-level `dtype` only works when
    # this file is run as a script and silently mismatches for other inputs.
    output = torch.empty((M_per_rank, N), dtype=input.dtype, device=input.device)

    # step 1: scatter each row-split of the input to the rank that owns it.
    intra_node_scatter(input, scatter_bufs, local_rank, stream)

    # step 2: waits for all ranks to complete the scatter.
    barrier_all_intra_node_atomic_cas_block[(1, )](local_rank, local_rank, local_world_size, sync_buf)

    # step 3: perform reduction to get the result of the intra-node reduce-scatter.
    ring_reduce(scatter_bufs[local_rank], output, local_rank, local_world_size, stream)

    return output
Benchmark
def torch_rs(
    input: torch.Tensor,  # [M, N]
    TP_GROUP,
):
    """Reference reduce-scatter built on `torch.distributed`.

    Args:
        input: 2-D tensor of shape [M, N] contributed by this rank.
        TP_GROUP: process group to reduce-scatter over.

    Returns:
        A tensor of shape [M // group_size, N] with the dtype and device of
        `input`, filled by `torch.distributed.reduce_scatter_tensor`.
    """
    M, N = input.shape
    # Size the output from the group actually used for the collective rather
    # than from a WORLD_SIZE global that only exists when run as a script.
    group_size = torch.distributed.get_world_size(group=TP_GROUP)
    rs_output = torch.empty((M // group_size, N), dtype=input.dtype, device=input.device)
    torch.distributed.reduce_scatter_tensor(rs_output, input, group=TP_GROUP)
    return rs_output
if __name__ == "__main__":
    # Rank layout is read from the environment; defaults describe a single process.
    RANK, LOCAL_RANK = (int(os.environ.get(key, 0)) for key in ("RANK", "LOCAL_RANK"))
    WORLD_SIZE, LOCAL_WORLD_SIZE = (int(os.environ.get(key, 1)) for key in ("WORLD_SIZE", "LOCAL_WORLD_SIZE"))

    TP_GROUP = triton_dist.utils.initialize_distributed()
    torch.cuda.synchronize()
    assert LOCAL_WORLD_SIZE == WORLD_SIZE, "runs on 1 node expected."

    # Problem setup: one [M, N] bf16 input per rank plus the communication buffers.
    dtype = torch.bfloat16
    M, N = 8192, 16384
    full_input = torch.rand((M, N), dtype=dtype).cuda()
    scatter_bufs = nvshmem_create_tensors([M, N], dtype, RANK, LOCAL_WORLD_SIZE)
    sync_buf = nvshmem_create_tensor((LOCAL_WORLD_SIZE, ), dtype=torch.int32)
    sync_buf.fill_(0)

    # Reference result from torch.distributed.
    reference_out = torch_rs(full_input, TP_GROUP)

    # Result from the Triton implementation, fenced by barriers on the current stream.
    nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
    triton_out = reducer_scatter_intra_node(
        full_input,
        scatter_bufs,
        sync_buf,
        LOCAL_RANK,
        LOCAL_WORLD_SIZE,
    )
    nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
    torch.cuda.synchronize()

    # Compare both results with the same absolute and relative tolerance.
    torch.testing.assert_close(reference_out, triton_out, atol=6e-2, rtol=6e-2)
    torch.cuda.synchronize()
    print(f"RANK {LOCAL_RANK}: pass!")

    # Release the buffers created above, then shut down NVSHMEM and the process group.
    nvshmem_free_tensor_sync(sync_buf)
    nvshmem_free_tensor_sync(scatter_bufs[LOCAL_RANK])
    nvshmem.core.finalize()
    torch.distributed.destroy_process_group()