Overlapping AllGather GEMM

In this tutorial, you will write a simple AllGather GEMM fusion kernel using Triton-distributed.

In doing so, you will learn about:

  • Writing a GEMM kernel that consumes the results of AllGather.

  • Optimizing inter-node communication with a 2D AllGather.

# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/07-overlapping-allgather-gemm.py

GEMM Kernel

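The AllGather kernel from the previous tutorial writes per-rank signals that computation kernels can wait on. In an AllGather GEMM, this lets computation overlap with communication: each GEMM tile waits only for the rows of A it actually needs, instead of waiting for the whole AllGather to finish. The GEMM itself needs few changes. Compared to Triton's default persistent GEMM, you only add a wait on those signals before loading each tile, plus an optional swizzle of the M dimension so each rank starts with rows that are already available locally.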

def _matmul_launch_metadata(grid, kernel, args):
    ret = {}
    M, N, K = args["M"], args["N"], args["K"]
    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
    ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
    return ret


@triton_dist.jit(launch_metadata=_matmul_launch_metadata)
def kernel_consumer_gemm_persistent(a_ptr, b_ptr, c_ptr,  #
                                    M, N, K,  #
                                    rank: tl.constexpr, num_ranks: tl.constexpr, ready_ptr, comm_buf_ptr,
                                    BLOCK_SIZE_M: tl.constexpr,  #
                                    BLOCK_SIZE_N: tl.constexpr,  #
                                    BLOCK_SIZE_K: tl.constexpr,  #
                                    GROUP_SIZE_M: tl.constexpr,  #
                                    EPILOGUE_SUBTILE: tl.constexpr,  #
                                    NUM_SMS: tl.constexpr, ready_value: tl.constexpr = 1,
                                    local_world_size: tl.constexpr = 8):  #
    # Persistent consumer GEMM: C[M, N] = A[M, K] @ B[N, K]^T, where A is the
    # AllGather output buffer that other ranks may still be filling.
    #
    # Synchronization: before the first K-step of each output tile, the program
    # calls dl.wait (GPU scope, acquire) on the ready flags of every rank whose
    # M_per_rank-row segment of A overlaps the tile's rows, waiting for
    # ready_value. Only then are A tiles loaded.
    #
    # Scheduling: program `start_pid` owns tiles start_pid, start_pid + NUM_SMS, ...
    # The (tile, k-step) iteration space is flattened into a single loop in which
    # `ki` cycles through 0 .. k_tiles - 1.
    #
    # Arguments (as used below):
    #   a_ptr, b_ptr, c_ptr: row-major A (M, K), B (N, K), C (M, N); strides are fixed
    #     by the tensor descriptors, so all three must be contiguous.
    #   rank, num_ranks: this rank and the rank count; M is split evenly, M_per_rank = M // num_ranks.
    #   ready_ptr: one ready flag per source rank, indexed by rank.
    #   comm_buf_ptr: not referenced in this kernel body.
    #   EPILOGUE_SUBTILE: store each C tile as two BLOCK_SIZE_N // 2 halves.
    #   NUM_SMS: number of persistent programs, i.e. the tile stride.
    #   ready_value: flag value that marks a segment of A as ready.
    #   local_world_size: ranks per node; nnodes = num_ranks // local_world_size.
    #     NOTE(review): must match the real node size. If local_world_size > num_ranks,
    #     nnodes == 0 and the multi-node swizzle below evaluates `% nnodes`.
    # Matmul using TMA and device-side descriptor creation
    dtype = c_ptr.dtype.element_ty
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    node_id = rank // local_world_size
    nnodes = num_ranks // local_world_size

    a_desc = tl.make_tensor_descriptor(
        a_ptr,
        shape=[M, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    b_desc = tl.make_tensor_descriptor(
        b_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )
    c_desc = tl.make_tensor_descriptor(
        c_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[
            BLOCK_SIZE_M,
            BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2,
        ],
    )

    # Split num_tiles across NUM_SMS programs; the first (num_tiles % NUM_SMS)
    # programs take one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Start one step "before" the first tile so that the first loop iteration
    # (ki -> 0) advances tile_id to start_pid.
    tile_id = start_pid - NUM_SMS
    ki = -1

    pid_m = 0
    pid_n = 0
    offs_am = 0
    offs_bn = 0

    # Rows of A contributed by each rank, and the number of M-tiles covering one rank's rows.
    M_per_rank = M // num_ranks
    pid_ms_per_rank = tl.cdiv(M_per_rank, BLOCK_SIZE_M)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            # New output tile: compute grouped (pid_m, pid_n) as in Triton's persistent GEMM.
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            # swizzle m
            # Rotate the M-tile order per rank so that the tiles with the smallest
            # pid_m land in this rank's own segment of A (whose flag is set first).
            if nnodes == 1:
                # Single node: shift by rank * pid_ms_per_rank (alpha/beta are fixed at 0 here).
                alpha = 0
                beta = 0
                pid_m = (pid_m + ((((rank ^ alpha) + beta) % num_ranks) * pid_ms_per_rank)) % num_pid_m
            else:
                # Multi-node: decompose the source rank of this M-tile into (node, local rank),
                # rotate the node by this node_id and the local rank by this rank, then rebuild.
                # For m_rank == 0 this yields swizzle_m_rank == rank.
                m_rank = pid_m // pid_ms_per_rank
                pid_m_intra_rank = pid_m - m_rank * pid_ms_per_rank
                m_node_id = m_rank // local_world_size
                m_local_rank = m_rank % local_world_size
                swizzle_m_node_id = (m_node_id + node_id) % nnodes
                swizzle_m_local_rank = (m_local_rank + rank) % local_world_size
                swizzle_m_rank = swizzle_m_node_id * local_world_size + swizzle_m_local_rank

                pid_m = swizzle_m_rank * pid_ms_per_rank + pid_m_intra_rank

            offs_am = pid_m * BLOCK_SIZE_M
            offs_bn = pid_n * BLOCK_SIZE_N

            # First and last source rank whose rows intersect [offs_am, min(offs_am + BLOCK_SIZE_M, M)).
            rank_beg = offs_am // M_per_rank
            rank_end = (min(offs_am + BLOCK_SIZE_M, M) - 1) // M_per_rank
            # Each tile wait for the corresponding data to ready
            token = dl.wait(ready_ptr + rank_beg, rank_end - rank_beg + 1, "gpu", "acquire", waitValue=ready_value)
            # Tie a_desc to the wait token; presumably this keeps the A loads below
            # ordered after the wait (confirm against triton_dist docs).
            a_desc = dl.consume_token(a_desc, token)

        offs_k = ki * BLOCK_SIZE_K
        # Iteration along k-dimension, and performing multiply.
        a = a_desc.load([offs_am, offs_k])
        b = b_desc.load([offs_bn, offs_k])
        accumulator = tl.dot(a, b.T, accumulator)

        if ki == k_tiles - 1:
            # Last K-step of this tile: cast to C's element type and store.
            if EPILOGUE_SUBTILE:
                # Split the accumulator along N into two halves and store them separately.
                acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                acc = tl.permute(acc, (0, 2, 1))
                acc0, acc1 = tl.split(acc)
                c0 = acc0.to(dtype)
                c_desc.store([offs_am, offs_bn], c0)
                c1 = acc1.to(dtype)
                c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c1)
            else:
                c = accumulator.to(dtype)
                c_desc.store([offs_am, offs_bn], c)

            # Reset for the next tile owned by this program.
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

AllGather Kernel

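Intra-node links have much more bandwidth than inter-node links. To use both fully, the inter-node AllGather is split into two parts. First, each rank sends its local shard to the ranks with the same local rank on the other nodes. Then the copy engine forwards the rank's own shard, and every shard received from other nodes, to the remaining ranks within the node.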

def inter_node_allgather(local_tensor: torch.Tensor, ag_buffer: list[torch.Tensor], signal_buffer: list[torch.Tensor],
                         signal_target, rank, local_world_size, world_size, intranode_ag_stream, internode_ag_stream):
    """Run the two-level (inter-node, then intra-node) AllGather.

    Stage 1, on ``internode_ag_stream``, launches ``n_nodes - 1`` Triton programs.
    They push this rank's segment of ``ag_buffer[local_rank]`` to the rank with the
    same local index on every other node and set the matching signal there.

    Stage 2, on ``intranode_ag_stream``, uses the copy engine to forward this
    rank's segment to the other ranks on this node. It also forwards every segment
    received from another node, once that segment's signal arrives.

    Finally, ``intranode_ag_stream`` is made to wait on ``internode_ag_stream``,
    so waiting on the intra-node stream covers both stages.

    Args:
        local_tensor: this rank's (M_per_rank, N) shard; used for its shape and element size.
        ag_buffer: per-local-rank AllGather buffers. This rank's segment is read from
            ``ag_buffer[local_rank]`` at offset ``rank * M_per_rank * N`` elements.
        signal_buffer: per-local-rank signal buffers, indexed by source rank.
        signal_target: value written into a signal slot to mark a segment as delivered.
        rank, local_world_size, world_size: global rank, ranks per node, total ranks.
        intranode_ag_stream, internode_ag_stream: CUDA streams for the two stages.

    NOTE: this function does not order ``internode_ag_stream`` after the writer of
    this rank's segment in ``ag_buffer``; the caller must establish that ordering.
    """
    my_local_rank = rank % local_world_size
    num_nodes = world_size // local_world_size
    rows_per_rank, cols = local_tensor.shape

    # Inter-node stage. With 2 nodes x 4 GPUs, each rank talks to exactly one peer:
    #   [0] <-> [4], [1] <-> [5], [2] <-> [6], [3] <-> [7]
    with torch.cuda.stream(internode_ag_stream):
        put_grid = lambda META: (int(num_nodes - 1), )
        nvshmem_device_producer_p2p_put_block_kernel[put_grid](
            ag_buffer[my_local_rank],
            signal_buffer[my_local_rank],
            rows_per_rank * cols,
            local_tensor.element_size(),
            signal_target,
            rank,
            local_world_size,
            world_size,
            num_warps=32,  # 32 warps x 32 threads = 1024 threads per program
        )

    # Intra-node stage. Each rank forwards its own shard plus the shards it received
    # from other nodes to the rest of its node, e.g. for 2 nodes x 4 GPUs:
    #   [0] sends shards 0 and 4 to [1, 2, 3];  [4] sends shards 0 and 4 to [5, 6, 7]
    #   [1] sends shards 1 and 5 to [0, 2, 3];  [5] sends shards 1 and 5 to [4, 6, 7]
    #   [2] sends shards 2 and 6 to [0, 1, 3];  [6] sends shards 2 and 6 to [4, 5, 7]
    #   [3] sends shards 3 and 7 to [0, 1, 2];  [7] sends shards 3 and 7 to [4, 5, 6]
    with torch.cuda.stream(intranode_ag_stream):
        cp_engine_producer_all_gather_put(
            local_tensor,
            ag_buffer,
            signal_buffer,
            rows_per_rank,
            cols,
            signal_target,
            rank,
            local_world_size,
            world_size,
            intranode_ag_stream,
        )

    # Join: anything that waits on the intra-node stream also sees the inter-node puts.
    intranode_ag_stream.wait_stream(internode_ag_stream)

Let’s declare a function to perform internode communication.

@triton_dist.jit
def nvshmem_device_producer_p2p_put_block_kernel(
    ag_buffer_ptr,  # *Pointer* to allgather output vector. The rank-th index has been loaded with local tensor
    signal_buffer_ptr,  # *Pointer* to signal barrier.
    elem_per_rank,  # Number of elements in one rank's segment of the AllGather buffer.
    size_per_elem,  # Element size in bytes; bytes sent = elem_per_rank * size_per_elem.
    signal_target,  # Value written into the peer's signal slot (NVSHMEM_SIGNAL_SET).
    rank,  # Global rank of the sender; selects both the segment and the signal slot.
    local_world_size,  # Ranks per node.
    world_size,  # Total number of ranks.
):
    # Inter-node stage of the 2D AllGather. This program sends this rank's segment
    # of the symmetric AllGather buffer to the rank with the same local index on
    # other nodes, and sets signal_buffer[rank] on that peer to signal_target.
    pid = tl.program_id(axis=0)
    num_pid = tl.num_programs(axis=0)

    n_nodes = world_size // local_world_size
    local_rank = rank % local_world_size
    node_rank = rank // local_world_size

    for i in range(pid, n_nodes - 1, num_pid):
        # Programs stride over the n_nodes - 1 remote peers (i = pid, pid + num_pid, ...).
        # The launcher in this tutorial uses n_nodes - 1 programs, so each program handles one peer.
        # The peer is the rank with the same local_rank on node (node_rank + i + 1) % n_nodes.
        peer = local_rank + (node_rank + i + 1) % n_nodes * local_world_size
        # We use putmem_signal_block to send data and notify the peer.
        # Since this is the allgather operation, the offsets of both src and dst tensor are both *rank*.
        libshmem_device.putmem_signal_block(
            ag_buffer_ptr + rank * elem_per_rank,
            ag_buffer_ptr + rank * elem_per_rank,
            elem_per_rank * size_per_elem,
            signal_buffer_ptr + rank,
            signal_target,
            libshmem_device.NVSHMEM_SIGNAL_SET,
            peer,
        )

Let’s also declare a function to perform intranode communication.

def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank,
                                      local_world_size, world_size, intranode_ag_stream):
    """Fan AllGather segments out to the other ranks of this node with the copy engine.

    First, this rank's own segment (index ``rank``) is copied from
    ``ag_buffer[local_rank]`` to every other local rank, each followed by a signal.
    Then, for every other node, the function waits on ``intranode_ag_stream`` until
    the segment from the same-local-rank peer on that node has arrived, and forwards
    it in the same way.

    All copies, waits and signals are enqueued on ``intranode_ag_stream``.

    Args:
        local_tensor: this rank's shard; only its element size is used here.
        ag_buffer: per-local-rank AllGather buffers (one segment of M_per_rank * N elements per rank).
        signal_buffer: per-local-rank signal buffers indexed by source rank.
        M_per_rank, N: shape of one rank's segment.
        signal_target: value that marks a segment as delivered.
        rank, local_world_size, world_size: global rank, ranks per node, total ranks.
        intranode_ag_stream: stream on which everything is enqueued.

    Raises:
        RuntimeError: if ``cudaMemcpyAsync`` reports an error. Previously the error
            was ignored and the destination was still signalled as ready.
    """
    local_rank = rank % local_world_size
    n_nodes = world_size // local_world_size
    node_rank = rank // local_world_size

    # Loop-invariant quantities: bytes per segment, raw stream handle, and this rank's buffer base.
    elem_size = local_tensor.element_size()
    segment_bytes = M_per_rank * N * elem_size
    stream_handle = intranode_ag_stream.cuda_stream
    local_base = ag_buffer[local_rank].data_ptr()

    def _forward_segment(src_rank):
        # Copy segment `src_rank` from this rank's buffer to each other local rank, then signal it.
        byte_offset = src_rank * M_per_rank * N * elem_size
        for j in range(1, local_world_size):
            local_dst_rank = (local_rank + local_world_size - j) % local_world_size
            dst_ptr = ag_buffer[local_dst_rank].data_ptr() + byte_offset
            (err, ) = cudart.cudaMemcpyAsync(
                dst_ptr,
                local_base + byte_offset,
                segment_bytes,
                cudart.cudaMemcpyKind.cudaMemcpyDefault,
                stream_handle,
            )
            if err != cudart.cudaError_t.cudaSuccess:
                # Do not signal a segment that was never copied: consumers would read stale data.
                raise RuntimeError(f"cudaMemcpyAsync of segment {src_rank} to local rank {local_dst_rank} "
                                   f"failed: {err}")
            # Notify the peer that the transmission is done.
            set_signal(signal_buffer[local_dst_rank][src_rank].data_ptr(), signal_target, intranode_ag_stream, True)

    # Sending rank-th local tensor to other ranks inside the node.
    _forward_segment(rank)

    for i in range(1, n_nodes):
        # Same local rank on node (node_rank - i) mod n_nodes.
        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
        # Waiting for the internode data ready
        wait_eq(signal_buffer[local_rank][recv_rank].data_ptr(), signal_target, intranode_ag_stream, True)
        _forward_segment(recv_rank)

AllGather GEMM Kernel

Now we combine all the kernels here.

def ag_gemm_persistent_op(a, b, c, rank, num_ranks, workspace_tensors, barrier_tensors, comm_buf, ag_stream=None,
                          internode_ag_stream=None, BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, stages=3, local_world_size=8,
                          signal_target=1):
    """Launch the overlapped two-level AllGather and the persistent consumer GEMM.

    Computes ``c = allgather(a) @ b.T``. The AllGather runs on ``ag_stream``
    (intra-node) and ``internode_ag_stream`` (inter-node). The GEMM runs on the
    current stream and waits per tile on the ready flags in ``barrier_tensors``,
    so computation on already-arrived rows overlaps with communication.

    Args:
        a: this rank's (M_per_rank, K) shard of A. The caller is expected to have
            copied it into its slot of ``workspace_tensors[local_rank]`` and set its
            own ready flag beforehand, on the current stream.
        b: this rank's (N_per_rank, K) weight shard; the GEMM multiplies by ``b.T``.
        c: (M, N_per_rank) output, where M = M_per_rank * num_ranks.
        rank, num_ranks: global rank and world size.
        workspace_tensors: per-local-rank symmetric AllGather buffers; the GEMM reads
            the first M rows of this rank's buffer.
        barrier_tensors: per-local-rank ready-flag buffers indexed by source rank.
        comm_buf: forwarded to the kernel, which currently does not use it.
        ag_stream, internode_ag_stream: streams for the intra-/inter-node stages;
            a new stream is created for each one that is ``None``.
        BLOCK_M, BLOCK_N, BLOCK_K, stages: GEMM tile sizes and pipeline depth.
        local_world_size: ranks per node; ``num_ranks`` must be a multiple of it.
        signal_target: flag value that marks a segment of A as ready.

    Returns:
        The handle returned by the Triton kernel launch.
    """
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert num_ranks % local_world_size == 0, "num_ranks must be a multiple of local_world_size"

    M_per_rank, K = a.shape
    M = M_per_rank * num_ranks
    N_per_rank = b.shape[0]

    local_rank = rank % local_world_size
    n_nodes = num_ranks // local_world_size
    # The inter-node put kernel launches n_nodes - 1 programs; keep those SMs free of GEMM work.
    num_ag_sms = n_nodes - 1
    num_gemm_sms = torch.cuda.get_device_properties("cuda").multi_processor_count - num_ag_sms

    current_stream = torch.cuda.current_stream()
    if ag_stream is None:
        ag_stream = torch.cuda.Stream()
    if internode_ag_stream is None:
        internode_ag_stream = torch.cuda.Stream()
    # Both AllGather stages read this rank's segment of the workspace, which the
    # caller writes on the current stream, so order both streams after it.
    ag_stream.wait_stream(current_stream)
    internode_ag_stream.wait_stream(current_stream)

    inter_node_allgather(a, workspace_tensors, barrier_tensors, signal_target, rank, local_world_size, num_ranks,
                         ag_stream, internode_ag_stream)

    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        # Scratch memory for the kernel's device-side tensor descriptors.
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)
    grid = lambda META: (min(
        num_gemm_sms,
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N_per_rank, META["BLOCK_SIZE_N"]),
    ), )
    compiled = kernel_consumer_gemm_persistent[grid](
        workspace_tensors[local_rank][:M],
        b,
        c,
        M,
        N_per_rank,
        K,
        rank,
        num_ranks,
        barrier_tensors[local_rank],
        comm_buf,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        8,  # GROUP_SIZE_M
        False,  # EPILOGUE_SUBTILE
        NUM_SMS=num_gemm_sms,
        ready_value=signal_target,
        # Must be forwarded: the kernel derives node_id / nnodes from it for the M swizzle.
        local_world_size=local_world_size,
        num_stages=stages,
        num_warps=8,
    )

    # Make the caller's stream observe completion of both AllGather stages.
    current_stream.wait_stream(internode_ag_stream)
    current_stream.wait_stream(ag_stream)
    return compiled

Benchmark

def torch_ag_gemm(
    pg: torch.distributed.ProcessGroup,
    local_input: torch.Tensor,
    local_weight: torch.Tensor,
    ag_out: torch.Tensor,
):
    """Reference AllGather-GEMM built only from stock PyTorch ops.

    Gathers every rank's ``local_input`` into ``ag_out`` (overwritten in place)
    with ``all_gather_into_tensor`` over ``pg``, then returns ``ag_out @ local_weight``.
    """
    torch.distributed.all_gather_into_tensor(ag_out, local_input, group=pg)
    return ag_out @ local_weight


if __name__ == "__main__":
    import sys

    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "-1"))
    LOCAL_WORLD_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "-1"))

    # This demo exercises the inter-node path (2+ nodes) and needs sm90+, because
    # the GEMM creates its tensor descriptors on the device. Bail out early otherwise.
    if WORLD_SIZE == LOCAL_WORLD_SIZE:
        skip_msg = "Skip the test because this should be performed with 2 nodes or higher"
    elif torch.cuda.get_device_capability()[0] < 9:
        skip_msg = "Skip the test because the device is not sm90 or higher"
    else:
        skip_msg = None
    if skip_msg is not None:
        print(skip_msg)
        sys.exit()

    TP_GROUP = initialize_distributed()
    rank = TP_GROUP.rank()

    # Global problem: C[M, N] = A[M, K] @ B[N, K]^T, with A sharded along M and
    # B sharded along N across WORLD_SIZE ranks.
    M, N, K = 8192, 49152, 12288
    config = {"BM": 128, "BN": 256, "BK": 64, "stage": 3}
    dtype = torch.float16

    assert M % WORLD_SIZE == 0
    assert N % WORLD_SIZE == 0
    M_per_rank, N_per_rank = M // WORLD_SIZE, N // WORLD_SIZE

    A = torch.randn([M_per_rank, K], dtype=dtype, device="cuda")
    B = torch.randn([N_per_rank, K], dtype=dtype, device="cuda")

    # Reference result: torch.distributed all-gather followed by torch.matmul.
    ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda")
    golden = torch_ag_gemm(TP_GROUP, A, B.T, ag_buffer)

    # We can use a context to wrap all the tensors used at runtime.
    # We rely on NVSHMEM to allocate the symmetric memory for communication
    # In practice, the following parts are encapsulated in ag_gemm_inter_node() of triton_dist.kernels.nvidia.allgather_gemm.py
    C = torch.empty([M, N_per_rank], dtype=dtype, device="cuda")
    ctx = create_ag_gemm_context(
        A,
        B,
        rank,
        WORLD_SIZE,
        max_M=M,
        BLOCK_M=config["BM"],
        BLOCK_N=config["BN"],
        BLOCK_K=config["BK"],
        stages=config["stage"],
    )

    # Clear the ready flags, then barrier across all ranks (presumably so no peer
    # signals into a flag that has not been cleared yet).
    ctx.symm_barrier.fill_(0)
    cur_stream = torch.cuda.current_stream()
    nvshmem_barrier_all_on_stream(cur_stream)

    # Place this rank's shard in its slot of the symmetric workspace and mark it ready.
    own_rows = slice(rank * M_per_rank, (rank + 1) * M_per_rank)
    ctx.symm_workspace[own_rows, :].copy_(A)
    set_signal(ctx.symm_barrier[rank].data_ptr(), 1, cur_stream, True)

    # launch the ag_gemm kernel
    ag_gemm_persistent_op(
        A,
        B,
        C,
        ctx.rank,
        ctx.num_ranks,
        ctx.symm_workspaces,
        ctx.symm_barriers,
        ctx.symm_comm_buf,
        ag_stream=ctx.ag_intranode_stream,
        internode_ag_stream=ctx.ag_internode_stream,
        local_world_size=LOCAL_WORLD_SIZE,
        signal_target=1,
    )

    assert torch.allclose(golden, C, atol=1e-3, rtol=1e-3)
    print("Pass!")

    ctx.finalize()
    nvshmem.core.finalize()
    torch.distributed.destroy_process_group()