Intra-node ReduceScatter

In this tutorial, you will write an intra-node reduce-scatter operation.

In doing so, you will learn about:

  • How to use the copy engine to communicate data within a node.

  • How to write an intra-node reduce-scatter kernel.

# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/05-intra-node-reduce-scatter.py

Reduce Kernel

The difference between this kernel and a normal reduction lies in the accumulation order. A normal reduction accumulates the splits starting from split 0. This kernel instead starts from split `(begin_idx + 1) % num_splits` and wraps around, so split `begin_idx` is accumulated last.

# Ring-ordered reduction over the `num_splits` row-splits stacked in `c_ptr`.
#
# `c_ptr` is viewed (through a tensor descriptor) as a row-major
# [M_per_rank * num_splits, N] matrix made of `num_splits` stacked
# [M_per_rank, N] splits. For every BLOCK_SIZE_M x BLOCK_SIZE_N output tile the
# kernel sums the matching tile of each split. It starts from split
# `(begin_idx + 1) % num_splits` and wraps around, so split `begin_idx` is added
# last. The sum is stored into the row-major [M_per_rank, N] matrix at `out_ptr`.
#
# NOTE(review): the accumulator keeps the dtype loaded from `c_ptr` (bf16 in the
# driver of this tutorial); there is no fp32 upcast during accumulation.
@triton.jit
def kernel_ring_reduce(
    c_ptr,  # [M_per_rank * num_splits, N], row-major (strides [N, 1])
    out_ptr,  # [M_per_rank, N], row-major (strides [N, 1])
    # shape of matrix
    M_per_rank,  # rows in each split, and rows of the output
    N,
    begin_idx,  # index of the split that is accumulated last
    num_splits: tl.constexpr,
    # reduce tile shape
    BLOCK_SIZE_M: tl.constexpr = 256,
    BLOCK_SIZE_N: tl.constexpr = 64,
):
    # Input descriptor bounds the whole stacked matrix, not an individual split.
    # NOTE(review): TMA descriptors generally need 16-byte-aligned strides
    # (N * element_size % 16 == 0); this is not checked here -- confirm.
    c_desc = tl.make_tensor_descriptor(
        c_ptr,
        shape=[M_per_rank * num_splits, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )
    output_desc = tl.make_tensor_descriptor(
        out_ptr,
        shape=[M_per_rank, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )

    pid = tl.program_id(axis=0)
    num_pid = tl.num_programs(axis=0)
    num_tiles_m = tl.cdiv(M_per_rank, BLOCK_SIZE_M)
    num_tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    total_tiles = num_tiles_m * num_tiles_n
    # Program `pid` handles tiles pid, pid + num_pid, ... so the kernel covers
    # every tile whether the grid has one program per tile or fewer programs.
    for tile_id in range(pid, total_tiles, num_pid):
        tile_id_m = tile_id // num_tiles_n
        tile_id_n = tile_id % num_tiles_n
        # Seed the accumulator with the first split in ring order instead of
        # zero-initializing it.
        # NOTE(review): when M_per_rank is not a multiple of BLOCK_SIZE_M, a tail
        # tile's load also covers rows of the following split; those rows are
        # presumably dropped by output_desc's [M_per_rank, N] bounds on store.
        cur_rank = (begin_idx + 1) % num_splits
        accum = c_desc.load([tile_id_m * BLOCK_SIZE_M + cur_rank * M_per_rank, tile_id_n * BLOCK_SIZE_N])
        # Add the remaining num_splits - 1 splits; the last one visited is begin_idx.
        for i in range(1, num_splits):
            cur_rank = (i + begin_idx + 1) % num_splits
            data = c_desc.load([tile_id_m * BLOCK_SIZE_M + cur_rank * M_per_rank, tile_id_n * BLOCK_SIZE_N])
            accum += data

        output_desc.store([tile_id_m * BLOCK_SIZE_M, tile_id_n * BLOCK_SIZE_N], accum)
def ring_reduce(
    input,  # [M_per_node, N]
    output,  # [M_per_rank, N]
    begin_idx,
    num_splits,
    stream,
    num_sms=-1,
):
    """Reduce the ``num_splits`` stacked row-splits of ``input`` into ``output``.

    ``input`` is viewed as ``num_splits`` stacked ``[M_per_split, N]`` splits, and
    ``output`` receives their element-wise sum. Splits are accumulated in ring
    order, starting at split ``(begin_idx + 1) % num_splits`` and ending at split
    ``begin_idx`` (see ``kernel_ring_reduce``).

    Args:
        input: Contiguous row-major CUDA tensor of shape
            ``[M_per_split * num_splits, N]``.
        output: Contiguous row-major CUDA tensor of shape ``[M_per_split, N]``;
            written in place.
        begin_idx: Index of the split that is accumulated last.
        num_splits: Number of splits stacked in ``input``.
        stream: CUDA stream the kernel is launched on.
        num_sms: ``-1`` (default) launches one program per output tile using
            256x64 tiles and 4 warps. Any other value caps the grid at ``num_sms``
            programs that loop over 256x128 tiles with 8 warps.

    Returns:
        ``output``.

    Raises:
        AssertionError: If ``input``'s row count is not divisible by
            ``num_splits``, the shapes of ``input`` and ``output`` disagree, or
            either tensor is not contiguous.
    """

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)
    total_M, N = input.shape
    M_per_split = total_M // num_splits
    assert output.shape[0] == M_per_split and total_M % num_splits == 0
    # kernel_ring_reduce hard-codes row-major strides [N, 1] in its tensor
    # descriptors, so a column mismatch or a non-contiguous view would silently
    # read/write the wrong elements.
    assert output.shape[1] == N, "output must have the same number of columns as input"
    assert input.is_contiguous() and output.is_contiguous(), "input and output must be contiguous"

    # The default launch uses one program per tile; the num_sms-capped launch
    # (programs loop over tiles) uses wider tiles and more warps per program.
    if num_sms == -1:
        block_size_n, num_warps = 64, 4
    else:
        block_size_n, num_warps = 128, 8

    def grid(META):
        num_tiles = triton.cdiv(M_per_split, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])
        return (num_tiles if num_sms == -1 else min(num_tiles, num_sms), )

    with torch.cuda.stream(stream):
        kernel_ring_reduce[grid](
            input,
            output,
            M_per_split,
            N,
            begin_idx,
            num_splits,
            BLOCK_SIZE_M=256,
            BLOCK_SIZE_N=block_size_n,
            num_warps=num_warps,
        )

    return output

Scatter Kernel

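We perform a rank-level swizzle in the scatter step. Each rank starts scattering to the rank right after itself and proceeds around the ring. At every time step, each rank therefore sends to exactly one peer and receives from exactly one peer, which balances the send/recv communication volume across ranks. For time steps 0 to local_world_size - 1, the communication order between ranks (shown for local_world_size = 4) is: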

  • time 0: 0->1, 1->2, 2->3, 3->0

  • time 1: 0->2, 1->3, 2->0, 3->1

  • time 2: 0->3, 1->0, 2->1, 3->2

  • time 3: 0->0, 1->1, 2->2, 3->3

def intra_node_scatter(input_intra_node, scatter_bufs_intra_node: List[torch.Tensor], local_rank, stream):
    """Scatter the row-chunks of ``input_intra_node`` into every rank's scatter buffer.

    ``input_intra_node`` is split into ``local_world_size`` chunks of
    ``M_per_rank`` rows. Chunk ``r`` is copied into rows
    ``[local_rank * M_per_rank, (local_rank + 1) * M_per_rank)`` of
    ``scatter_bufs_intra_node[r]``. Once every rank has run this, buffer ``r``
    holds all ranks' copies of chunk ``r``, stacked in rank order.

    Destinations are visited in the order ``local_rank + 1, local_rank + 2, ...``
    (mod ``local_world_size``), ending with the local copy. At each step every
    rank therefore targets a different peer.

    Args:
        input_intra_node: 2-D ``[M, N]`` tensor; ``M`` must be divisible by
            ``len(scatter_bufs_intra_node)``.
        scatter_bufs_intra_node: One buffer per local rank, indexed by rank, each
            with at least ``M`` rows and ``N`` columns.
        local_rank: This rank's index into ``scatter_bufs_intra_node``.
        stream: CUDA stream on which the copies are issued.

    Raises:
        ValueError: If ``M`` is not divisible by the number of buffers.
    """
    M, _ = input_intra_node.shape
    local_world_size = len(scatter_bufs_intra_node)
    if M % local_world_size != 0:
        # Otherwise the trailing M % local_world_size rows would be silently dropped.
        raise ValueError(f"number of rows ({M}) must be divisible by the number of ranks ({local_world_size})")
    M_per_rank = M // local_world_size
    # Rows of each remote buffer that belong to this rank's contribution.
    dst_rows = slice(local_rank * M_per_rank, (local_rank + 1) * M_per_rank)

    # Send input_intra_node[remote * M_per_rank : (remote + 1) * M_per_rank] on the current rank to
    # scatter_bufs_intra_node[remote][local_rank * M_per_rank : (local_rank + 1) * M_per_rank].
    with torch.cuda.stream(stream):
        for step in range(local_world_size):
            remote_local_rank = (local_rank + step + 1) % local_world_size
            src = input_intra_node[remote_local_rank * M_per_rank:(remote_local_rank + 1) * M_per_rank, :]
            # Use the copy engine to perform the scatter (torch issues a `cudaMemcpy`
            # for contiguous data).
            scatter_bufs_intra_node[remote_local_rank][dst_rows, :].copy_(src)

Reduce-Scatter

@p2p_native_atomic_required
def reducer_scatter_intra_node(input, scatter_bufs, sync_buf, local_rank, local_world_size):
    """Intra-node reduce-scatter: scatter, barrier, then a ring-ordered local reduction.

    Args:
        input: ``[M, N]`` tensor on this rank; ``M`` must be divisible by
            ``local_world_size``.
        scatter_bufs: Per-rank ``[M, N]`` buffers, indexed by local rank, that this
            rank writes into (peers' buffers included).
        sync_buf: Buffer passed to the intra-node barrier kernel.
        local_rank: This rank's index within the node.
        local_world_size: Number of ranks in the node.

    Returns:
        A new ``[M // local_world_size, N]`` tensor with ``input``'s dtype and
        device. It holds the sum over all ranks of row-chunk ``local_rank``.
    """
    stream = torch.cuda.current_stream()
    M, N = input.shape
    M_per_rank = M // local_world_size

    # Take the dtype from the input rather than a global that only exists when
    # this file is run as a script.
    output = torch.empty((M_per_rank, N), dtype=input.dtype, device=input.device)
    # step 1: intra-node scatter of each row-chunk into its owning rank's buffer.
    intra_node_scatter(input, scatter_bufs, local_rank, stream)

    # step 2: wait for all ranks to complete the scatter.
    barrier_all_intra_node_atomic_cas_block[(1, )](local_rank, local_rank, local_world_size, sync_buf)
    # step 3: reduce the gathered chunks to get the result of the intra-node reduce-scatter.
    ring_reduce(scatter_bufs[local_rank], output, local_rank, local_world_size, stream)
    return output

Benchmark

def torch_rs(
    input: torch.Tensor,  # [M, N]
    TP_GROUP,
):
    """Reference reduce-scatter built on ``torch.distributed.reduce_scatter_tensor``.

    Args:
        input: ``[M, N]`` tensor; ``M`` must be divisible by the size of ``TP_GROUP``.
        TP_GROUP: Process group to reduce-scatter over.

    Returns:
        A ``[M // group_size, N]`` tensor with ``input``'s dtype and device that
        holds this rank's chunk of the reduced result.
    """
    M, N = input.shape
    # Derive the output size from the group itself instead of a module-level
    # WORLD_SIZE that only exists when this file is run as a script.
    group_size = torch.distributed.get_world_size(TP_GROUP)
    rs_output = torch.empty((M // group_size, N), dtype=input.dtype, device=input.device)
    torch.distributed.reduce_scatter_tensor(rs_output, input, group=TP_GROUP)
    return rs_output


if __name__ == "__main__":
    # Driver: compares reducer_scatter_intra_node against torch.distributed's
    # reduce_scatter_tensor on a single node, one process per launcher rank.
    # init: read the launcher-provided rank / world-size environment variables and
    # initialize the distributed runtime via triton_dist.
    RANK = int(os.environ.get("RANK", 0))
    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
    LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
    TP_GROUP = triton_dist.utils.initialize_distributed()
    torch.cuda.synchronize()

    # The intra-node implementation only covers ranks on the same node.
    assert LOCAL_WORLD_SIZE == WORLD_SIZE, "runs on 1 node expected."

    dtype = torch.bfloat16
    M, N = 8192, 16384

    # Each rank draws its own random input; both implementations reduce these
    # same per-rank inputs, so no shared seed is needed.
    # NOTE(review): `input` shadows the builtin, and `.cuda()` materializes the
    # tensor on the CPU first.
    input = torch.rand((M, N), dtype=dtype).cuda()

    # Scatter buffers indexed by local rank (presumably views of one NVSHMEM
    # symmetric allocation) plus a LOCAL_WORLD_SIZE-element int32 barrier buffer.
    symm_scatter_bufs = nvshmem_create_tensors([M, N], dtype, RANK, LOCAL_WORLD_SIZE)
    symm_sync_buf = nvshmem_create_tensor((LOCAL_WORLD_SIZE, ), dtype=torch.int32)
    symm_sync_buf.fill_(0)

    # Reference result.
    torch_output = torch_rs(input, TP_GROUP)

    # Presumably ensures every rank's buffers are initialized before peers write into them.
    nvshmem_barrier_all_on_stream(torch.cuda.current_stream())

    dist_triton_output = reducer_scatter_intra_node(input, symm_scatter_bufs, symm_sync_buf, LOCAL_RANK,
                                                    LOCAL_WORLD_SIZE)

    nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
    torch.cuda.synchronize()

    # Loose tolerance: bf16 data, and presumably a different summation order than
    # the torch.distributed reference.
    atol, rtol = 6e-2, 6e-2
    torch.testing.assert_close(torch_output, dist_triton_output, atol=atol, rtol=rtol)
    torch.cuda.synchronize()
    print(f"RANK {LOCAL_RANK}: pass!")

    # Only this rank's own scatter buffer is freed; the other entries are
    # presumably peer views of the same symmetric allocation.
    nvshmem_free_tensor_sync(symm_sync_buf)
    nvshmem_free_tensor_sync(symm_scatter_bufs[LOCAL_RANK])
    nvshmem.core.finalize()
    torch.distributed.destroy_process_group()